diff --git a/.travis.yml b/.travis.yml index f29cb960..535bc9ae 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,7 +13,7 @@ before_install: - sudo dd if=/dev/urandom of=/usr/share/nginx/www/file bs=1M count=10 - sudo sh -c "echo '127.0.0.1 localhost' > /etc/hosts" - sudo service nginx restart - - pip install pep8 pyflakes nose coverage + - pip install pep8 pyflakes nose coverage PySocks - sudo tests/socksify/install.sh - sudo tests/libsodium/install.sh - sudo tests/setup_tc.sh diff --git a/CHANGES b/CHANGES index 4db142a6..a58348a4 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,21 @@ +2.6.13 2015-11-02 +- add protocol setting + +2.6.12 2015-10-27 +- IPv6 first +- Fix mem leaks +- auth_simple plugin +- remove FORCE_NEW_PROTOCOL +- optimize code + +2.6.11 2015-10-20 +- Obfs plugin +- Obfs parameters +- UDP over TCP +- TCP over UDP (experimental) +- Fix socket leaks +- Catch abnormal UDP package + 2.6.10 2015-06-08 - Optimize LRU cache - Refine logging diff --git a/setup.py b/setup.py index 07ea2db0..689dd736 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name="shadowsocks", - version="2.6.11", + version="2.6.12", license='http://www.apache.org/licenses/LICENSE-2.0', description="A fast tunnel proxy that help you get through firewalls", author='clowwindy', diff --git a/shadowsocks/asyncdns.py b/shadowsocks/asyncdns.py index 7e4a4ed7..d958752e 100644 --- a/shadowsocks/asyncdns.py +++ b/shadowsocks/asyncdns.py @@ -18,13 +18,19 @@ from __future__ import absolute_import, division, print_function, \ with_statement -import time import os import socket import struct import re import logging +if __name__ == '__main__': + import sys + import inspect + file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))) + os.chdir(file_path) + sys.path.insert(0, os.path.join(file_path, '../')) + from shadowsocks import common, lru_cache, eventloop, shell @@ -72,6 +78,19 @@ QTYPE_NS = 2 QCLASS_IN = 1 +def detect_ipv6_supprot(): + if 'has_ipv6' in dir(socket): + s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) + try: + s.connect(('ipv6.google.com', 0)) + print('IPv6 support') + return True + except: + pass + print('IPv6 not support') + return False + +IPV6_CONNECTION_SUPPORT = detect_ipv6_supprot() def build_address(address): address = address.strip(b'.') @@ -256,7 +275,6 @@ def __init__(self): self._hostname_to_cb = {} self._cb_to_hostname = {} self._cache = lru_cache.LRUCache(timeout=300) - self._last_time = time.time() self._sock = None self._servers = None self._parse_resolv() @@ -304,7 +322,7 @@ def _parse_hosts(self): except IOError: self._hosts['localhost'] = '127.0.0.1' - def add_to_loop(self, loop, ref=False): + def add_to_loop(self, loop): if self._loop: raise Exception('already add to loop') self._loop = loop @@ -312,8 +330,8 @@ def add_to_loop(self, loop, ref=False): self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.SOL_UDP) self._sock.setblocking(False) - loop.add(self._sock, eventloop.POLL_IN) - loop.add_handler(self.handle_events, ref=ref) + loop.add(self._sock, eventloop.POLL_IN, self) + loop.add_periodic(self.handle_periodic) def _call_callback(self, hostname, ip, error=None): callbacks = self._hostname_to_cb.get(hostname, []) @@ -340,44 +358,56 @@ def _handle_data(self, data): answer[2] == QCLASS_IN: ip = answer[0] break - if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \ - == STATUS_IPV4: - self._hostname_status[hostname] = STATUS_IPV6 - self._send_req(hostname, QTYPE_AAAA) - else: - if ip: - self._cache[hostname] = ip - self._call_callback(hostname, ip) - elif self._hostname_status.get(hostname, None) == STATUS_IPV6: - for question in response.questions: - if question[1] == QTYPE_AAAA: - self._call_callback(hostname, None) - break - - def handle_events(self, events): - for sock, fd, event in events: - if sock != self._sock: - continue - if event & eventloop.POLL_ERR: - logging.error('dns socket err') - self._loop.remove(self._sock) - self._sock.close() - # TODO when dns server is IPv6 - self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, - socket.SOL_UDP) - self._sock.setblocking(False) - self._loop.add(self._sock, eventloop.POLL_IN) + if IPV6_CONNECTION_SUPPORT: + if not ip and self._hostname_status.get(hostname, STATUS_IPV4) \ + == STATUS_IPV6: + self._hostname_status[hostname] = STATUS_IPV4 + self._send_req(hostname, QTYPE_A) + else: + if ip: + self._cache[hostname] = ip + self._call_callback(hostname, ip) + elif self._hostname_status.get(hostname, None) == STATUS_IPV4: + for question in response.questions: + if question[1] == QTYPE_A: + self._call_callback(hostname, None) + break else: - data, addr = sock.recvfrom(1024) - if addr[0] not in self._servers: - logging.warn('received a packet other than our dns') - break - self._handle_data(data) - break - now = time.time() - if now - self._last_time > CACHE_SWEEP_INTERVAL: - self._cache.sweep() - self._last_time = now + if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \ + == STATUS_IPV4: + self._hostname_status[hostname] = STATUS_IPV6 + self._send_req(hostname, QTYPE_AAAA) + else: + if ip: + self._cache[hostname] = ip + self._call_callback(hostname, ip) + elif self._hostname_status.get(hostname, None) == STATUS_IPV6: + for question in response.questions: + if question[1] == QTYPE_AAAA: + self._call_callback(hostname, None) + break + + def handle_event(self, sock, fd, event): + if sock != self._sock: + return + if event & eventloop.POLL_ERR: + logging.error('dns socket err') + self._loop.remove(self._sock) + self._sock.close() + # TODO when dns server is IPv6 + self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, + socket.SOL_UDP) + self._sock.setblocking(False) + self._loop.add(self._sock, eventloop.POLL_IN, self) + else: + data, addr = sock.recvfrom(1024) + if addr[0] not in self._servers: + logging.warn('received a packet other than our dns') + return + self._handle_data(data) + + def handle_periodic(self): + self._cache.sweep() def remove_callback(self, callback): hostname = self._cb_to_hostname.get(callback) @@ -419,17 +449,27 @@ def resolve(self, hostname, callback): return arr = self._hostname_to_cb.get(hostname, None) if not arr: - self._hostname_status[hostname] = STATUS_IPV4 - self._send_req(hostname, QTYPE_A) + if IPV6_CONNECTION_SUPPORT: + self._hostname_status[hostname] = STATUS_IPV6 + self._send_req(hostname, QTYPE_AAAA) + else: + self._hostname_status[hostname] = STATUS_IPV4 + self._send_req(hostname, QTYPE_A) self._hostname_to_cb[hostname] = [callback] self._cb_to_hostname[callback] = hostname else: arr.append(callback) # TODO send again only if waited too long - self._send_req(hostname, QTYPE_A) + if IPV6_CONNECTION_SUPPORT: + self._send_req(hostname, QTYPE_AAAA) + else: + self._send_req(hostname, QTYPE_A) def close(self): if self._sock: + if self._loop: + self._loop.remove_periodic(self.handle_periodic) + self._loop.remove(self._sock) self._sock.close() self._sock = None @@ -437,7 +477,7 @@ def close(self): def test(): dns_resolver = DNSResolver() loop = eventloop.EventLoop() - dns_resolver.add_to_loop(loop, ref=True) + dns_resolver.add_to_loop(loop) global counter counter = 0 @@ -451,8 +491,8 @@ def callback(result, error): print(result, error) counter += 1 if counter == 9: - loop.remove_handler(dns_resolver.handle_events) dns_resolver.close() + loop.stop() a_callback = callback return a_callback @@ -481,3 +521,4 @@ def callback(result, error): if __name__ == '__main__': test() + diff --git a/shadowsocks/common.py b/shadowsocks/common.py index fc03d556..7f306ea8 100644 --- a/shadowsocks/common.py +++ b/shadowsocks/common.py @@ -21,7 +21,7 @@ import socket import struct import logging - +import binascii def compat_ord(s): if type(s) == int: @@ -54,6 +54,16 @@ def to_str(s): return s.decode('utf-8') return s +def int32(x): + if x > 0xFFFFFFFF or x < 0: + x &= 0xFFFFFFFF + if x > 0x7FFFFFFF: + x = int(0x100000000 - x) + if x < 0x80000000: + return -x + else: + return -2147483648 + return x def inet_ntop(family, ipstr): if family == socket.AF_INET: @@ -138,12 +148,52 @@ def pack_addr(address): address = address[:255] # TODO return b'\x03' + chr(len(address)) + address +def pre_parse_header(data): + datatype = ord(data[0]) + if datatype == 0x80: + if len(data) <= 2: + return None + rand_data_size = ord(data[1]) + if rand_data_size + 2 >= len(data): + logging.warn('header too short, maybe wrong password or ' + 'encryption method') + return None + data = data[rand_data_size + 2:] + elif datatype == 0x81: + data = data[1:] + elif datatype == 0x82: + if len(data) <= 3: + return None + rand_data_size = struct.unpack('>H', data[1:3])[0] + if rand_data_size + 3 >= len(data): + logging.warn('header too short, maybe wrong password or ' + 'encryption method') + return None + data = data[rand_data_size + 3:] + elif datatype == 0x88 or (~datatype & 0xff) == 0x88: + if len(data) <= 7 + 7: + return None + data_size = struct.unpack('>H', data[1:3])[0] + ogn_data = data + data = data[:data_size] + crc = binascii.crc32(data) & 0xffffffff + if crc != 0xffffffff: + logging.warn('uncorrect CRC32, maybe wrong password or ' + 'encryption method') + return None + start_pos = 3 + ord(data[3]) + data = data[start_pos:-4] + if data_size < len(ogn_data): + data += ogn_data[data_size:] + return data def parse_header(data): addrtype = ord(data[0]) dest_addr = None dest_port = None header_length = 0 + connecttype = (addrtype & 0x10) and 1 or 0 + addrtype &= ~0x10 if addrtype == ADDRTYPE_IPV4: if len(data) >= 7: dest_addr = socket.inet_ntoa(data[1:5]) @@ -157,7 +207,7 @@ def parse_header(data): if len(data) >= 2 + addrlen: dest_addr = data[2:2 + addrlen] dest_port = struct.unpack('>H', data[2 + addrlen:4 + - addrlen])[0] + addrlen])[0] header_length = 4 + addrlen else: logging.warn('header is too short') @@ -175,7 +225,7 @@ def parse_header(data): 'encryption method' % addrtype) if dest_addr is None: return None - return addrtype, to_bytes(dest_addr), dest_port, header_length + return connecttype, to_bytes(dest_addr), dest_port, header_length class IPNetwork(object): diff --git a/shadowsocks/crypto/ctypes_libsodium.py b/shadowsocks/crypto/ctypes_libsodium.py new file mode 100644 index 00000000..efecfd41 --- /dev/null +++ b/shadowsocks/crypto/ctypes_libsodium.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python + +# Copyright (c) 2014 clowwindy +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import logging +from ctypes import CDLL, c_char_p, c_int, c_ulonglong, byref, \ + create_string_buffer, c_void_p + +__all__ = ['ciphers'] + +libsodium = None +loaded = False + +buf_size = 2048 + +# for salsa20 and chacha20 +BLOCK_SIZE = 64 + + +def load_libsodium(): + global loaded, libsodium, buf + + from ctypes.util import find_library + for p in ('sodium',): + libsodium_path = find_library(p) + if libsodium_path: + break + else: + raise Exception('libsodium not found') + logging.info('loading libsodium from %s', libsodium_path) + libsodium = CDLL(libsodium_path) + libsodium.sodium_init.restype = c_int + libsodium.crypto_stream_salsa20_xor_ic.restype = c_int + libsodium.crypto_stream_salsa20_xor_ic.argtypes = (c_void_p, c_char_p, + c_ulonglong, + c_char_p, c_ulonglong, + c_char_p) + libsodium.crypto_stream_chacha20_xor_ic.restype = c_int + libsodium.crypto_stream_chacha20_xor_ic.argtypes = (c_void_p, c_char_p, + c_ulonglong, + c_char_p, c_ulonglong, + c_char_p) + + libsodium.sodium_init() + + buf = create_string_buffer(buf_size) + loaded = True + + +class Salsa20Crypto(object): + def __init__(self, cipher_name, key, iv, op): + if not loaded: + load_libsodium() + self.key = key + self.iv = iv + self.key_ptr = c_char_p(key) + self.iv_ptr = c_char_p(iv) + if cipher_name == b'salsa20': + self.cipher = libsodium.crypto_stream_salsa20_xor_ic + elif cipher_name == b'chacha20': + self.cipher = libsodium.crypto_stream_chacha20_xor_ic + else: + raise Exception('Unknown cipher') + # byte counter, not block counter + self.counter = 0 + + def update(self, data): + global buf_size, buf + l = len(data) + + # we can only prepend some padding to make the encryption align to + # blocks + padding = self.counter % BLOCK_SIZE + if buf_size < padding + l: + buf_size = (padding + l) * 2 + buf = create_string_buffer(buf_size) + + if padding: + data = (b'\0' * padding) + data + self.cipher(byref(buf), c_char_p(data), padding + l, + self.iv_ptr, int(self.counter / BLOCK_SIZE), self.key_ptr) + self.counter += l + # buf is copied to a str object when we access buf.raw + # strip off the padding + return buf.raw[padding:padding + l] + + +ciphers = { + b'salsa20': (32, 8, Salsa20Crypto), + b'chacha20': (32, 8, Salsa20Crypto), +} + + +def test_salsa20(): + from shadowsocks.crypto import util + + cipher = Salsa20Crypto(b'salsa20', b'k' * 32, b'i' * 16, 1) + decipher = Salsa20Crypto(b'salsa20', b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +def test_chacha20(): + from shadowsocks.crypto import util + + cipher = Salsa20Crypto(b'chacha20', b'k' * 32, b'i' * 16, 1) + decipher = Salsa20Crypto(b'chacha20', b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +if __name__ == '__main__': + test_chacha20() + test_salsa20() diff --git a/shadowsocks/crypto/ctypes_openssl.py b/shadowsocks/crypto/ctypes_openssl.py new file mode 100644 index 00000000..0ef8ce0f --- /dev/null +++ b/shadowsocks/crypto/ctypes_openssl.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python + +# Copyright (c) 2014 clowwindy +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import logging +from ctypes import CDLL, c_char_p, c_int, c_long, byref,\ + create_string_buffer, c_void_p + +__all__ = ['ciphers'] + +libcrypto = None +loaded = False + +buf_size = 2048 + + +def load_openssl(): + global loaded, libcrypto, buf + + from ctypes.util import find_library + for p in ('crypto', 'eay32', 'libeay32'): + libcrypto_path = find_library(p) + if libcrypto_path: + break + else: + raise Exception('libcrypto(OpenSSL) not found') + logging.info('loading libcrypto from %s', libcrypto_path) + libcrypto = CDLL(libcrypto_path) + libcrypto.EVP_get_cipherbyname.restype = c_void_p + libcrypto.EVP_CIPHER_CTX_new.restype = c_void_p + + libcrypto.EVP_CipherInit_ex.argtypes = (c_void_p, c_void_p, c_char_p, + c_char_p, c_char_p, c_int) + + libcrypto.EVP_CipherUpdate.argtypes = (c_void_p, c_void_p, c_void_p, + c_char_p, c_int) + + libcrypto.EVP_CIPHER_CTX_cleanup.argtypes = (c_void_p,) + libcrypto.EVP_CIPHER_CTX_free.argtypes = (c_void_p,) + if hasattr(libcrypto, 'OpenSSL_add_all_ciphers'): + libcrypto.OpenSSL_add_all_ciphers() + + buf = create_string_buffer(buf_size) + loaded = True + + +def load_cipher(cipher_name): + func_name = b'EVP_' + cipher_name.replace(b'-', b'_') + if bytes != str: + func_name = str(func_name, 'utf-8') + cipher = getattr(libcrypto, func_name, None) + if cipher: + cipher.restype = c_void_p + return cipher() + return None + + +class CtypesCrypto(object): + def __init__(self, cipher_name, key, iv, op): + if not loaded: + load_openssl() + self._ctx = None + cipher = libcrypto.EVP_get_cipherbyname(cipher_name) + if not cipher: + cipher = load_cipher(cipher_name) + if not cipher: + raise Exception('cipher %s not found in libcrypto' % cipher_name) + key_ptr = c_char_p(key) + iv_ptr = c_char_p(iv) + self._ctx = libcrypto.EVP_CIPHER_CTX_new() + if not self._ctx: + raise Exception('can not create cipher context') + r = libcrypto.EVP_CipherInit_ex(self._ctx, cipher, None, + key_ptr, iv_ptr, c_int(op)) + if not r: + self.clean() + raise Exception('can not initialize cipher context') + + def update(self, data): + global buf_size, buf + cipher_out_len = c_long(0) + l = len(data) + if buf_size < l: + buf_size = l * 2 + buf = create_string_buffer(buf_size) + libcrypto.EVP_CipherUpdate(self._ctx, byref(buf), + byref(cipher_out_len), c_char_p(data), l) + # buf is copied to a str object when we access buf.raw + return buf.raw[:cipher_out_len.value] + + def __del__(self): + self.clean() + + def clean(self): + if self._ctx: + libcrypto.EVP_CIPHER_CTX_cleanup(self._ctx) + libcrypto.EVP_CIPHER_CTX_free(self._ctx) + + +ciphers = { + b'aes-128-cfb': (16, 16, CtypesCrypto), + b'aes-192-cfb': (24, 16, CtypesCrypto), + b'aes-256-cfb': (32, 16, CtypesCrypto), + b'aes-128-ofb': (16, 16, CtypesCrypto), + b'aes-192-ofb': (24, 16, CtypesCrypto), + b'aes-256-ofb': (32, 16, CtypesCrypto), + b'aes-128-ctr': (16, 16, CtypesCrypto), + b'aes-192-ctr': (24, 16, CtypesCrypto), + b'aes-256-ctr': (32, 16, CtypesCrypto), + b'aes-128-cfb8': (16, 16, CtypesCrypto), + b'aes-192-cfb8': (24, 16, CtypesCrypto), + b'aes-256-cfb8': (32, 16, CtypesCrypto), + b'aes-128-cfb1': (16, 16, CtypesCrypto), + b'aes-192-cfb1': (24, 16, CtypesCrypto), + b'aes-256-cfb1': (32, 16, CtypesCrypto), + b'bf-cfb': (16, 8, CtypesCrypto), + b'camellia-128-cfb': (16, 16, CtypesCrypto), + b'camellia-192-cfb': (24, 16, CtypesCrypto), + b'camellia-256-cfb': (32, 16, CtypesCrypto), + b'cast5-cfb': (16, 8, CtypesCrypto), + b'des-cfb': (8, 8, CtypesCrypto), + b'idea-cfb': (16, 8, CtypesCrypto), + b'rc2-cfb': (16, 8, CtypesCrypto), + b'rc4': (16, 0, CtypesCrypto), + b'seed-cfb': (16, 16, CtypesCrypto), +} + + +def run_method(method): + from shadowsocks.crypto import util + + cipher = CtypesCrypto(method, b'k' * 32, b'i' * 16, 1) + decipher = CtypesCrypto(method, b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +def test_aes_128_cfb(): + run_method(b'aes-128-cfb') + + +def test_aes_256_cfb(): + run_method(b'aes-256-cfb') + + +def test_aes_128_cfb8(): + run_method(b'aes-128-cfb8') + + +def test_aes_256_ofb(): + run_method(b'aes-256-ofb') + + +def test_aes_256_ctr(): + run_method(b'aes-256-ctr') + + +def test_bf_cfb(): + run_method(b'bf-cfb') + + +def test_rc4(): + run_method(b'rc4') + + +if __name__ == '__main__': + test_aes_128_cfb() diff --git a/shadowsocks/crypto/m2.py b/shadowsocks/crypto/m2.py new file mode 100644 index 00000000..4c7e1480 --- /dev/null +++ b/shadowsocks/crypto/m2.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python + +# Copyright (c) 2014 clowwindy +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import sys +import logging + +__all__ = ['ciphers'] + +has_m2 = True +try: + __import__('M2Crypto') +except ImportError: + has_m2 = False + + +def create_cipher(alg, key, iv, op, key_as_bytes=0, d=None, salt=None, i=1, + padding=1): + + import M2Crypto.EVP + return M2Crypto.EVP.Cipher(alg.replace('-', '_'), key, iv, op, + key_as_bytes=0, d='md5', salt=None, i=1, + padding=1) + + +def err(alg, key, iv, op, key_as_bytes=0, d=None, salt=None, i=1, padding=1): + logging.error(('M2Crypto is required to use %s, please run' + ' `apt-get install python-m2crypto`') % alg) + sys.exit(1) + + +if has_m2: + ciphers = { + b'aes-128-cfb': (16, 16, create_cipher), + b'aes-192-cfb': (24, 16, create_cipher), + b'aes-256-cfb': (32, 16, create_cipher), + b'bf-cfb': (16, 8, create_cipher), + b'camellia-128-cfb': (16, 16, create_cipher), + b'camellia-192-cfb': (24, 16, create_cipher), + b'camellia-256-cfb': (32, 16, create_cipher), + b'cast5-cfb': (16, 8, create_cipher), + b'des-cfb': (8, 8, create_cipher), + b'idea-cfb': (16, 8, create_cipher), + b'rc2-cfb': (16, 8, create_cipher), + b'rc4': (16, 0, create_cipher), + b'seed-cfb': (16, 16, create_cipher), + } +else: + ciphers = {} + + +def run_method(method): + from shadowsocks.crypto import util + + cipher = create_cipher(method, b'k' * 32, b'i' * 16, 1) + decipher = create_cipher(method, b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +def check_env(): + # skip this test on pypy and Python 3 + try: + import __pypy__ + del __pypy__ + from nose.plugins.skip import SkipTest + raise SkipTest + except ImportError: + pass + if bytes != str: + from nose.plugins.skip import SkipTest + raise SkipTest + + +def test_aes_128_cfb(): + check_env() + run_method(b'aes-128-cfb') + + +def test_aes_256_cfb(): + check_env() + run_method(b'aes-256-cfb') + + +def test_bf_cfb(): + check_env() + run_method(b'bf-cfb') + + +def test_rc4(): + check_env() + run_method(b'rc4') + + +if __name__ == '__main__': + test_aes_128_cfb() diff --git a/shadowsocks/crypto/util.py b/shadowsocks/crypto/util.py index e579455e..212df860 100644 --- a/shadowsocks/crypto/util.py +++ b/shadowsocks/crypto/util.py @@ -88,7 +88,8 @@ def find_library(possible_lib_names, search_symbol, library_name): logging.warn('can\'t find symbol %s in %s', search_symbol, path) except Exception: - pass + if path == paths[-1]: + raise return None diff --git a/shadowsocks/encrypt.py b/shadowsocks/encrypt.py index 4e87f415..d3b27527 100644 --- a/shadowsocks/encrypt.py +++ b/shadowsocks/encrypt.py @@ -47,6 +47,8 @@ def try_cipher(key, method=None): def EVP_BytesToKey(password, key_len, iv_len): # equivalent to OpenSSL's EVP_BytesToKey() with count 1 # so that we make the same key and iv as nodejs version + if hasattr(password, 'encode'): + password = password.encode('utf-8') cached_key = '%s-%d-%d' % (password, key_len, iv_len) r = cached_keys.get(cached_key, None) if r: @@ -75,6 +77,7 @@ def __init__(self, key, method): self.iv = None self.iv_sent = False self.cipher_iv = b'' + self.iv_buf = b'' self.decipher = None method = method.lower() self._method_info = self.get_method_info(method) @@ -120,16 +123,21 @@ def encrypt(self, buf): def decrypt(self, buf): if len(buf) == 0: return buf - if self.decipher is None: - decipher_iv_len = self._method_info[1] - decipher_iv = buf[:decipher_iv_len] + if self.decipher is not None: #optimize + return self.decipher.update(buf) + + decipher_iv_len = self._method_info[1] + if len(self.iv_buf) <= decipher_iv_len: + self.iv_buf += buf + if len(self.iv_buf) > decipher_iv_len: + decipher_iv = self.iv_buf[:decipher_iv_len] self.decipher = self.get_cipher(self.key, self.method, 0, iv=decipher_iv) - buf = buf[decipher_iv_len:] - if len(buf) == 0: - return buf - return self.decipher.update(buf) - + buf = self.iv_buf[decipher_iv_len:] + del self.iv_buf + return self.decipher.update(buf) + else: + return b'' def encrypt_all(password, method, op, data): result = [] diff --git a/shadowsocks/encrypt_test.py b/shadowsocks/encrypt_test.py new file mode 100644 index 00000000..68228e18 --- /dev/null +++ b/shadowsocks/encrypt_test.py @@ -0,0 +1,33 @@ +from __future__ import absolute_import, division, print_function, \ + with_statement + +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../')) + + +from crypto import rc4_md5 +from crypto import openssl +from crypto import sodium +from crypto import table + +def main(): + print("\n""rc4_md5") + rc4_md5.test() + print("\n""aes-256-cfb") + openssl.test_aes_256_cfb() + print("\n""aes-128-cfb") + openssl.test_aes_128_cfb() + print("\n""rc4") + openssl.test_rc4() + print("\n""salsa20") + sodium.test_salsa20() + print("\n""chacha20") + sodium.test_chacha20() + print("\n""table") + table.test_encryption() + +if __name__ == '__main__': + main() + diff --git a/shadowsocks/eventloop.py b/shadowsocks/eventloop.py index 42f9205b..ce9c11bc 100644 --- a/shadowsocks/eventloop.py +++ b/shadowsocks/eventloop.py @@ -22,6 +22,7 @@ with_statement import os +import time import socket import select import errno @@ -51,23 +52,8 @@ POLL_NVAL: 'POLL_NVAL', } - -class EpollLoop(object): - - def __init__(self): - self._epoll = select.epoll() - - def poll(self, timeout): - return self._epoll.poll(timeout) - - def add_fd(self, fd, mode): - self._epoll.register(fd, mode) - - def remove_fd(self, fd): - self._epoll.unregister(fd) - - def modify_fd(self, fd, mode): - self._epoll.modify(fd, mode) +# we check timeouts every TIMEOUT_PRECISION seconds +TIMEOUT_PRECISION = 10 class KqueueLoop(object): @@ -100,17 +86,20 @@ def poll(self, timeout): results[fd] |= POLL_OUT return results.items() - def add_fd(self, fd, mode): + def register(self, fd, mode): self._fds[fd] = mode self._control(fd, mode, select.KQ_EV_ADD) - def remove_fd(self, fd): + def unregister(self, fd): self._control(fd, self._fds[fd], select.KQ_EV_DELETE) del self._fds[fd] - def modify_fd(self, fd, mode): - self.remove_fd(fd) - self.add_fd(fd, mode) + def modify(self, fd, mode): + self.unregister(fd) + self.register(fd, mode) + + def close(self): + self.kqueue.close() class SelectLoop(object): @@ -129,7 +118,7 @@ def poll(self, timeout): results[fd] |= p[1] return results.items() - def add_fd(self, fd, mode): + def register(self, fd, mode): if mode & POLL_IN: self._r_list.add(fd) if mode & POLL_OUT: @@ -137,7 +126,7 @@ def add_fd(self, fd, mode): if mode & POLL_ERR: self._x_list.add(fd) - def remove_fd(self, fd): + def unregister(self, fd): if fd in self._r_list: self._r_list.remove(fd) if fd in self._w_list: @@ -145,16 +134,18 @@ def remove_fd(self, fd): if fd in self._x_list: self._x_list.remove(fd) - def modify_fd(self, fd, mode): - self.remove_fd(fd) - self.add_fd(fd, mode) + def modify(self, fd, mode): + self.unregister(fd) + self.register(fd, mode) + + def close(self): + pass class EventLoop(object): def __init__(self): - self._iterating = False if hasattr(select, 'epoll'): - self._impl = EpollLoop() + self._impl = select.epoll() model = 'epoll' elif hasattr(select, 'kqueue'): self._impl = KqueueLoop() @@ -165,72 +156,74 @@ def __init__(self): else: raise Exception('can not find any available functions in select ' 'package') - self._fd_to_f = {} - self._handlers = [] - self._ref_handlers = [] - self._handlers_to_remove = [] + self._fdmap = {} # (f, handler) + self._last_time = time.time() + self._periodic_callbacks = [] + self._stopping = False logging.debug('using event model: %s', model) def poll(self, timeout=None): events = self._impl.poll(timeout) - return [(self._fd_to_f[fd], fd, event) for fd, event in events] + return [(self._fdmap[fd][0], fd, event) for fd, event in events] - def add(self, f, mode): + def add(self, f, mode, handler): fd = f.fileno() - self._fd_to_f[fd] = f - self._impl.add_fd(fd, mode) + self._fdmap[fd] = (f, handler) + self._impl.register(fd, mode) def remove(self, f): fd = f.fileno() - del self._fd_to_f[fd] - self._impl.remove_fd(fd) + del self._fdmap[fd] + self._impl.unregister(fd) + + def add_periodic(self, callback): + self._periodic_callbacks.append(callback) + + def remove_periodic(self, callback): + self._periodic_callbacks.remove(callback) def modify(self, f, mode): fd = f.fileno() - self._impl.modify_fd(fd, mode) - - def add_handler(self, handler, ref=True): - self._handlers.append(handler) - if ref: - # when all ref handlers are removed, loop stops - self._ref_handlers.append(handler) - - def remove_handler(self, handler): - if handler in self._ref_handlers: - self._ref_handlers.remove(handler) - if self._iterating: - self._handlers_to_remove.append(handler) - else: - self._handlers.remove(handler) + self._impl.modify(fd, mode) + + def stop(self): + self._stopping = True def run(self): events = [] - while self._ref_handlers: + while not self._stopping: + asap = False try: - events = self.poll(1) + events = self.poll(TIMEOUT_PRECISION) except (OSError, IOError) as e: if errno_from_exception(e) in (errno.EPIPE, errno.EINTR): # EPIPE: Happens when the client closes the connection # EINTR: Happens when received a signal # handles them as soon as possible + asap = True logging.debug('poll:%s', e) else: logging.error('poll:%s', e) import traceback traceback.print_exc() continue - self._iterating = True - for handler in self._handlers: - # TODO when there are a lot of handlers - try: - handler(events) - except (OSError, IOError) as e: - shell.print_exception(e) - if self._handlers_to_remove: - for handler in self._handlers_to_remove: - self._handlers.remove(handler) - self._handlers_to_remove = [] - self._iterating = False + + for sock, fd, event in events: + handler = self._fdmap.get(fd, None) + if handler is not None: + handler = handler[1] + try: + handler.handle_event(sock, fd, event) + except (OSError, IOError) as e: + shell.print_exception(e) + now = time.time() + if asap or now - self._last_time >= TIMEOUT_PRECISION: + for callback in self._periodic_callbacks: + callback() + self._last_time = now + + def __del__(self): + self._impl.close() # from tornado diff --git a/shadowsocks/local.py b/shadowsocks/local.py index 4255a2ee..096283c1 100755 --- a/shadowsocks/local.py +++ b/shadowsocks/local.py @@ -23,7 +23,12 @@ import logging import signal -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../')) +if __name__ == '__main__': + import inspect + file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))) + os.chdir(file_path) + sys.path.insert(0, os.path.join(file_path, '../')) + from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, asyncdns @@ -38,6 +43,9 @@ def main(): config = shell.get_config(True) + if not config.get('dns_ipv6', False): + asyncdns.IPV6_CONNECTION_SUPPORT = False + daemon.daemon_exec(config) try: diff --git a/shadowsocks/lru_cache.py b/shadowsocks/lru_cache.py index 401f19b5..e67fdffe 100644 --- a/shadowsocks/lru_cache.py +++ b/shadowsocks/lru_cache.py @@ -88,12 +88,12 @@ def sweep(self): self.close_callback(value) self._closed_values.add(value) for key in self._time_to_keys[least]: - self._last_visits.popleft() if key in self._store: if now - self._keys_to_last_time[key] > self.timeout: del self._store[key] del self._keys_to_last_time[key] c += 1 + self._last_visits.popleft() del self._time_to_keys[least] if c: self._closed_values.clear() diff --git a/shadowsocks/manager.py b/shadowsocks/manager.py new file mode 100644 index 00000000..bfabd7fa --- /dev/null +++ b/shadowsocks/manager.py @@ -0,0 +1,288 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import errno +import traceback +import socket +import logging +import json +import collections + +from shadowsocks import common, eventloop, tcprelay, udprelay, asyncdns, shell + + +BUF_SIZE = 1506 +STAT_SEND_LIMIT = 50 + + +class Manager(object): + + def __init__(self, config): + self._config = config + self._relays = {} # (tcprelay, udprelay) + self._loop = eventloop.EventLoop() + self._dns_resolver = asyncdns.DNSResolver() + self._dns_resolver.add_to_loop(self._loop) + + self._statistics = collections.defaultdict(int) + self._control_client_addr = None + try: + manager_address = common.to_str(config['manager_address']) + if ':' in manager_address: + addr = manager_address.rsplit(':', 1) + addr = addr[0], int(addr[1]) + addrs = socket.getaddrinfo(addr[0], addr[1]) + if addrs: + family = addrs[0][0] + else: + logging.error('invalid address: %s', manager_address) + exit(1) + else: + addr = manager_address + family = socket.AF_UNIX + self._control_socket = socket.socket(family, + socket.SOCK_DGRAM) + self._control_socket.bind(addr) + self._control_socket.setblocking(False) + except (OSError, IOError) as e: + logging.error(e) + logging.error('can not bind to manager address') + exit(1) + self._loop.add(self._control_socket, + eventloop.POLL_IN, self) + self._loop.add_periodic(self.handle_periodic) + + port_password = config['port_password'] + del config['port_password'] + for port, password in port_password.items(): + a_config = config.copy() + a_config['server_port'] = int(port) + a_config['password'] = password + self.add_port(a_config) + + def add_port(self, config): + port = int(config['server_port']) + servers = self._relays.get(port, None) + if servers: + logging.error("server already exists at %s:%d" % (config['server'], + port)) + return + logging.info("adding server at %s:%d" % (config['server'], port)) + t = tcprelay.TCPRelay(config, self._dns_resolver, False, + self.stat_callback) + u = udprelay.UDPRelay(config, self._dns_resolver, False, + self.stat_callback) + t.add_to_loop(self._loop) + u.add_to_loop(self._loop) + self._relays[port] = (t, u) + + def remove_port(self, config): + port = int(config['server_port']) + servers = self._relays.get(port, None) + if servers: + logging.info("removing server at %s:%d" % (config['server'], port)) + t, u = servers + t.close(next_tick=False) + u.close(next_tick=False) + del self._relays[port] + else: + logging.error("server not exist at %s:%d" % (config['server'], + port)) + + def handle_event(self, sock, fd, event): + if sock == self._control_socket and event == eventloop.POLL_IN: + data, self._control_client_addr = sock.recvfrom(BUF_SIZE) + parsed = self._parse_command(data) + if parsed: + command, config = parsed + a_config = self._config.copy() + if config: + # let the command override the configuration file + a_config.update(config) + if 'server_port' not in a_config: + logging.error('can not find server_port in config') + else: + if command == 'add': + self.add_port(a_config) + self._send_control_data(b'ok') + elif command == 'remove': + self.remove_port(a_config) + self._send_control_data(b'ok') + elif command == 'ping': + self._send_control_data(b'pong') + else: + logging.error('unknown command %s', command) + + def _parse_command(self, data): + # commands: + # add: {"server_port": 8000, "password": "foobar"} + # remove: {"server_port": 8000"} + data = common.to_str(data) + parts = data.split(':', 1) + if len(parts) < 2: + return data, None + command, config_json = parts + try: + config = shell.parse_json_in_str(config_json) + return command, config + except Exception as e: + logging.error(e) + return None + + def stat_callback(self, port, data_len): + self._statistics[port] += data_len + + def handle_periodic(self): + r = {} + i = 0 + + def send_data(data_dict): + if data_dict: + # use compact JSON format (without space) + data = common.to_bytes(json.dumps(data_dict, + separators=(',', ':'))) + self._send_control_data(b'stat: ' + data) + + for k, v in self._statistics.items(): + r[k] = v + i += 1 + # split the data into segments that fit in UDP packets + if i >= STAT_SEND_LIMIT: + send_data(r) + r.clear() + i = 0 + if len(r) > 0 : + send_data(r) + self._statistics.clear() + + def _send_control_data(self, data): + if self._control_client_addr: + try: + self._control_socket.sendto(data, self._control_client_addr) + except (socket.error, OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + return + else: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + + def run(self): + self._loop.run() + + +def run(config): + Manager(config).run() + + +def test(): + import time + import threading + import struct + from shadowsocks import encrypt + + logging.basicConfig(level=5, + format='%(asctime)s %(levelname)-8s %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + enc = [] + eventloop.TIMEOUT_PRECISION = 1 + + def run_server(): + config = { + 'server': '127.0.0.1', + 'local_port': 1081, + 'port_password': { + '8381': 'foobar1', + '8382': 'foobar2' + }, + 'method': 'aes-256-cfb', + 'manager_address': '127.0.0.1:6001', + 'timeout': 60, + 'fast_open': False, + 'verbose': 2 + } + manager = Manager(config) + enc.append(manager) + manager.run() + + t = threading.Thread(target=run_server) + t.start() + time.sleep(1) + manager = enc[0] + cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + cli.connect(('127.0.0.1', 6001)) + + # test add and remove + time.sleep(1) + cli.send(b'add: {"server_port":7001, "password":"asdfadsfasdf"}') + time.sleep(1) + assert 7001 in manager._relays + data, addr = cli.recvfrom(1506) + assert b'ok' in data + + cli.send(b'remove: {"server_port":8381}') + time.sleep(1) + assert 8381 not in manager._relays + data, addr = cli.recvfrom(1506) + assert b'ok' in data + logging.info('add and remove test passed') + + # test statistics for TCP + header = common.pack_addr(b'google.com') + struct.pack('>H', 80) + data = encrypt.encrypt_all(b'asdfadsfasdf', 'aes-256-cfb', 1, + header + b'GET /\r\n\r\n') + tcp_cli = socket.socket() + tcp_cli.connect(('127.0.0.1', 7001)) + tcp_cli.send(data) + tcp_cli.recv(4096) + tcp_cli.close() + + data, addr = cli.recvfrom(1506) + data = common.to_str(data) + assert data.startswith('stat: ') + data = data.split('stat:')[1] + stats = shell.parse_json_in_str(data) + assert '7001' in stats + logging.info('TCP statistics test passed') + + # test statistics for UDP + header = common.pack_addr(b'127.0.0.1') + struct.pack('>H', 80) + data = encrypt.encrypt_all(b'foobar2', 'aes-256-cfb', 1, + header + b'test') + udp_cli = socket.socket(type=socket.SOCK_DGRAM) + udp_cli.sendto(data, ('127.0.0.1', 8382)) + tcp_cli.close() + + data, addr = cli.recvfrom(1506) + data = common.to_str(data) + assert data.startswith('stat: ') + data = data.split('stat:')[1] + stats = json.loads(data) + assert '8382' in stats + logging.info('UDP statistics test passed') + + manager._loop.stop() + t.join() + + +if __name__ == '__main__': + test() diff --git a/shadowsocks/obfs.py b/shadowsocks/obfs.py new file mode 100644 index 00000000..1752a56e --- /dev/null +++ b/shadowsocks/obfs.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python +# +# Copyright 2015-2015 breakwa11 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import sys +import hashlib +import logging + +from shadowsocks import common +from shadowsocks.obfsplugin import plain, http_simple, verify_simple + + +method_supported = {} +method_supported.update(plain.obfs_map) +method_supported.update(http_simple.obfs_map) +method_supported.update(verify_simple.obfs_map) + +class server_info(object): + def __init__(self, data): + self.data = data + +class obfs(object): + def __init__(self, method): + self.method = method + self._method_info = self.get_method_info(method) + if self._method_info: + self.obfs = self.get_obfs(method) + else: + raise Exception('method %s not supported' % method) + + def init_data(self): + return self.obfs.init_data() + + def set_server_info(self, server_info): + return self.obfs.set_server_info(server_info) + + def get_method_info(self, method): + method = method.lower() + m = method_supported.get(method) + return m + + def get_obfs(self, method): + m = self._method_info + return m[0](method) + + def client_pre_encrypt(self, buf): + return self.obfs.client_pre_encrypt(buf) + + def client_encode(self, buf): + return self.obfs.client_encode(buf) + + def client_decode(self, buf): + return self.obfs.client_decode(buf) + + def client_post_decrypt(self, buf): + return self.obfs.client_post_decrypt(buf) + + def server_pre_encrypt(self, buf): + return self.obfs.server_pre_encrypt(buf) + + def server_encode(self, buf): + return self.obfs.server_encode(buf) + + def server_decode(self, buf): + return self.obfs.server_decode(buf) + + def server_post_decrypt(self, buf): + return self.obfs.server_post_decrypt(buf) + + def dispose(self): + self.obfs.dispose() + del self.obfs + diff --git a/shadowsocks/obfsplugin/__init__.py b/shadowsocks/obfsplugin/__init__.py new file mode 100644 index 00000000..401c7b72 --- /dev/null +++ b/shadowsocks/obfsplugin/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement diff --git a/shadowsocks/obfsplugin/http_simple.py b/shadowsocks/obfsplugin/http_simple.py new file mode 100644 index 00000000..8144aa80 --- /dev/null +++ b/shadowsocks/obfsplugin/http_simple.py @@ -0,0 +1,366 @@ +#!/usr/bin/env python +# +# Copyright 2015-2015 breakwa11 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import sys +import hashlib +import logging +import binascii +import struct +import base64 +import datetime +import random + +from shadowsocks import common +from shadowsocks.obfsplugin import plain +from shadowsocks.common import to_bytes, to_str, ord + +def create_http_obfs(method): + return http_simple(method) + +def create_http2_obfs(method): + return http2_simple(method) + +def create_tls_obfs(method): + return tls_simple(method) + +def create_random_head_obfs(method): + return random_head(method) + +obfs_map = { + 'http_simple': (create_http_obfs,), + 'http_simple_compatible': (create_http_obfs,), + 'http2_simple': (create_http2_obfs,), + 'http2_simple_compatible': (create_http2_obfs,), + 'tls_simple': (create_tls_obfs,), + 'tls_simple_compatible': (create_tls_obfs,), + 'random_head': (create_random_head_obfs,), + 'random_head_compatible': (create_random_head_obfs,), +} + +def match_begin(str1, str2): + if len(str1) >= len(str2): + if str1[:len(str2)] == str2: + return True + return False + +class http_simple(plain.plain): + def __init__(self, method): + self.method = method + self.has_sent_header = False + self.has_recv_header = False + self.host = None + self.port = 0 + self.recv_buffer = b'' + self.user_agent = [b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/40.0", + b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/44.0", + b"Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36", + b"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/535.11 (KHTML, like Gecko) Ubuntu/11.10 Chromium/27.0.1453.93 Chrome/27.0.1453.93 Safari/537.36", + b"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:35.0) Gecko/20100101 Firefox/35.0", + b"Mozilla/5.0 (compatible; WOW64; MSIE 10.0; Windows NT 6.2)", + b"Mozilla/5.0 (Windows; U; Windows NT 6.1; en-US) AppleWebKit/533.20.25 (KHTML, like Gecko) Version/5.0.4 Safari/533.20.27", + b"Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.3; Trident/7.0; .NET4.0E; .NET4.0C)", + b"Mozilla/5.0 (Windows NT 6.3; Trident/7.0; rv:11.0) like Gecko", + b"Mozilla/5.0 (Linux; Android 4.4; Nexus 5 Build/BuildID) AppleWebKit/537.36 (KHTML, like Gecko) Version/4.0 Chrome/30.0.0.0 Mobile Safari/537.36", + b"Mozilla/5.0 (iPad; CPU OS 5_0 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9A334 Safari/7534.48.3", + b"Mozilla/5.0 (iPhone; CPU iPhone OS 5_0 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9A334 Safari/7534.48.3"] + + def encode_head(self, buf): + ret = b'' + for ch in buf: + ret += '%' + binascii.hexlify(ch) + return ret + + def client_encode(self, buf): + if self.has_sent_header: + return buf + if len(buf) > 64: + headlen = random.randint(1, 64) + else: + headlen = len(buf) + headdata = buf[:headlen] + buf = buf[headlen:] + port = b'' + if self.server_info.port != 80: + port = b':' + common.to_bytes(str(self.server_info.port)) + http_head = b"GET /" + self.encode_head(headdata) + b" HTTP/1.1\r\n" + http_head += b"Host: " + (self.server_info.param or self.server_info.host) + port + b"\r\n" + http_head += b"User-Agent: " + random.choice(self.user_agent) + b"\r\n" + http_head += b"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\nAccept-Language: en-US,en;q=0.8\r\nAccept-Encoding: gzip, deflate\r\nDNT: 1\r\nConnection: keep-alive\r\n\r\n" + self.has_sent_header = True + return http_head + buf + + def client_decode(self, buf): + if self.has_recv_header: + return (buf, False) + pos = buf.find(b'\r\n\r\n') + if pos >= 0: + self.has_recv_header = True + return (buf[pos + 4:], False) + else: + return (b'', False) + + def server_encode(self, buf): + if self.has_sent_header: + return buf + + header = b'HTTP/1.1 200 OK\r\nServer: openresty\r\nDate: ' + header += to_bytes(datetime.datetime.now().strftime('%a, %d %b %Y %H:%M:%S GMT')) + header += b'\r\nContent-Type: text/plain; charset=utf-8\r\nTransfer-Encoding: chunked\r\nConnection: keep-alive\r\nKeep-Alive: timeout=20\r\nVary: Accept-Encoding\r\nContent-Encoding: gzip\r\n\r\n' + self.has_sent_header = True + return header + buf + + def get_data_from_http_header(self, buf): + ret_buf = b'' + lines = buf.split(b'\r\n') + if lines and len(lines) > 4: + hex_items = lines[0].split(b'%') + if hex_items and len(hex_items) > 1: + for index in range(1, len(hex_items)): + if len(hex_items[index]) != 2: + ret_buf += binascii.unhexlify(hex_items[index][:2]) + break + ret_buf += binascii.unhexlify(hex_items[index]) + return ret_buf + return b'' + + def not_match_return(self, buf): + self.has_sent_header = True + self.has_recv_header = True + if self.method == 'http_simple': + return (b'E', False, False) + return (buf, True, False) + + def server_decode(self, buf): + if self.has_recv_header: + return (buf, True, False) + + self.recv_buffer += buf + buf = self.recv_buffer + if len(buf) > 10: + if match_begin(buf, b'GET /') or match_begin(buf, b'POST /'): + if len(buf) > 65536: + self.recv_buffer = None + logging.warn('http_simple: over size') + return self.not_match_return(buf) + else: #not http header, run on original protocol + self.recv_buffer = None + logging.debug('http_simple: not match begin') + return self.not_match_return(buf) + else: + return (b'', True, False) + + datas = buf.split(b'\r\n\r\n', 1) + if datas and len(datas) > 1: + ret_buf = self.get_data_from_http_header(buf) + ret_buf += datas[1] + if len(ret_buf) >= 15: + self.has_recv_header = True + return (ret_buf, True, False) + return (b'', True, False) + else: + return (b'', True, False) + +class http2_simple(plain.plain): + def __init__(self, method): + self.method = method + self.has_sent_header = False + self.has_recv_header = False + self.raw_trans_sent = False + self.host = None + self.port = 0 + self.recv_buffer = b'' + + def client_encode(self, buf): + if self.raw_trans_sent: + return buf + self.send_buffer += buf + if not self.has_sent_header: + port = b'' + if self.server_info.port != 80: + port = b':' + common.to_bytes(str(self.server_info.port)) + self.has_sent_header = True + http_head = b"GET / HTTP/1.1\r\n" + http_head += b"Host: " + (self.server_info.param or self.server_info.host) + port + b"\r\n" + http_head += b"Connection: Upgrade, HTTP2-Settings\r\nUpgrade: h2c\r\n" + http_head += b"HTTP2-Settings: " + base64.urlsafe_b64encode(buf) + b"\r\n" + return http_head + b"\r\n" + if self.has_recv_header: + ret = self.send_buffer + self.send_buffer = b'' + self.raw_trans_sent = True + return ret + return b'' + + def client_decode(self, buf): + if self.has_recv_header: + return (buf, False) + pos = buf.find(b'\r\n\r\n') + if pos >= 0: + self.has_recv_header = True + return (buf[pos + 4:], False) + else: + return (b'', False) + + def server_encode(self, buf): + if self.has_sent_header: + return buf + + header = b'HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: h2c\r\n\r\n' + self.has_sent_header = True + return header + buf + + def not_match_return(self, buf): + self.has_sent_header = True + self.has_recv_header = True + if self.method == 'http2_simple': + return (b'E', False, False) + return (buf, True, False) + + def server_decode(self, buf): + if self.has_recv_header: + return (buf, True, False) + + self.recv_buffer += buf + buf = self.recv_buffer + if len(buf) > 10: + if match_begin(buf, b'GET /'): + pass + else: #not http header, run on original protocol + self.recv_buffer = None + return self.not_match_return(buf) + else: + return (b'', True, False) + + datas = buf.split(b'\r\n\r\n', 1) + if datas and len(datas) > 1 and len(datas[0]) >= 4: + lines = buf.split(b'\r\n') + if lines and len(lines) >= 4: + if match_begin(lines[4], b'HTTP2-Settings: '): + ret_buf = base64.urlsafe_b64decode(lines[4][16:]) + ret_buf += datas[1] + self.has_recv_header = True + return (ret_buf, True, False) + return (b'', True, False) + else: + return (b'', True, False) + return self.not_match_return(buf) + +class tls_simple(plain.plain): + def __init__(self, method): + self.method = method + self.has_sent_header = False + self.has_recv_header = False + self.raw_trans_sent = False + + def client_encode(self, buf): + if self.raw_trans_sent: + return buf + self.send_buffer += buf + if not self.has_sent_header: + self.has_sent_header = True + data = b"\x03\x03" + os.urandom(32) + binascii.unhexlify(b"000016c02bc02fc00ac009c013c01400330039002f0035000a0100006fff01000100000a00080006001700180019000b0002010000230000337400000010002900270568322d31360568322d31350568322d313402683208737064792f332e3108687474702f312e31000500050100000000000d001600140401050106010201040305030603020304020202") + data = b"\x01\x00" + struct.pack('>H', len(data)) + data + data = b"\x16\x03\x01" + struct.pack('>H', len(data)) + data + return data + if self.has_recv_header: + ret = self.send_buffer + self.send_buffer = b'' + self.raw_trans_sent = True + return ret + return b'' + + def client_decode(self, buf): + if self.has_recv_header: + return (buf, False) + self.has_recv_header = True + return (b'', True) + + def server_encode(self, buf): + if self.has_sent_header: + return buf + self.has_sent_header = True + # TODO + #server_hello = b'' + return b'\x16\x03\x01' + + def server_decode(self, buf): + if self.has_recv_header: + return (buf, True, False) + + self.has_recv_header = True + if not match_begin(buf, b'\x16\x03\x01'): + self.has_sent_header = True + if self.method == 'tls_simple': + return (b'E', False, False) + return (buf, True, False) + # (buffer_to_recv, is_need_decrypt, is_need_to_encode_and_send_back) + return (b'', False, True) + +class random_head(plain.plain): + def __init__(self, method): + self.method = method + self.has_sent_header = False + self.has_recv_header = False + self.raw_trans_sent = False + self.raw_trans_recv = False + self.send_buffer = b'' + + def client_encode(self, buf): + if self.raw_trans_sent: + return buf + self.send_buffer += buf + if not self.has_sent_header: + self.has_sent_header = True + data = os.urandom(common.ord(os.urandom(1)[0]) % 96 + 4) + crc = (0xffffffff - binascii.crc32(data)) & 0xffffffff + return data + struct.pack('= len(str2): + if str1[:len(str2)] == str2: + return True + return False + +class obfs_verify_data(object): + def __init__(self): + pass + +class verify_base(plain.plain): + def __init__(self, method): + super(verify_base, self).__init__(method) + self.method = method + + def init_data(self): + return obfs_verify_data() + + def set_server_info(self, server_info): + self.server_info = server_info + + def client_encode(self, buf): + return buf + + def client_decode(self, buf): + return (buf, False) + + def server_encode(self, buf): + return buf + + def server_decode(self, buf): + return (buf, True, False) + + def get_head_size(self, buf, def_value): + if len(buf) < 2: + return def_value + if ord(buf[0]) == 1: + return 7 + if ord(buf[0]) == 4: + return 19 + if ord(buf[0]) == 3: + return 4 + ord(buf[1]) + return def_value + +class verify_simple(verify_base): + def __init__(self, method): + super(verify_simple, self).__init__(method) + self.recv_buf = b'' + self.unit_len = 8100 + self.decrypt_packet_num = 0 + self.raw_trans = False + + def pack_data(self, buf): + if len(buf) == 0: + return b'' + rnd_data = os.urandom(common.ord(os.urandom(1)[0]) % 16) + data = common.chr(len(rnd_data) + 1) + rnd_data + buf + data = struct.pack('>H', len(data) + 6) + data + crc = (0xffffffff - binascii.crc32(data)) & 0xffffffff + data += struct.pack(' self.unit_len: + ret += self.pack_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_data(buf) + return ret + + def client_post_decrypt(self, buf): + if self.raw_trans: + return buf + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 2: + length = struct.unpack('>H', self.recv_buf[:2])[0] + if length >= 8192: + self.raw_trans = True + self.recv_buf = b'' + if self.decrypt_packet_num == 0: + return None + else: + raise Exception('server_post_decrype data error') + if length > len(self.recv_buf): + break + + if (binascii.crc32(self.recv_buf[:length]) & 0xffffffff) != 0xffffffff: + self.raw_trans = True + self.recv_buf = b'' + if self.decrypt_packet_num == 0: + return None + else: + raise Exception('server_post_decrype data uncorrect CRC32') + + pos = common.ord(self.recv_buf[2]) + 2 + out_buf += self.recv_buf[pos:length - 4] + self.recv_buf = self.recv_buf[length:] + + if out_buf: + self.decrypt_packet_num += 1 + return out_buf + + def server_pre_encrypt(self, buf): + ret = b'' + while len(buf) > self.unit_len: + ret += self.pack_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_data(buf) + return ret + + def server_post_decrypt(self, buf): + if self.raw_trans: + return buf + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 2: + length = struct.unpack('>H', self.recv_buf[:2])[0] + if length >= 8192: + self.raw_trans = True + self.recv_buf = b'' + if self.decrypt_packet_num == 0: + return b'E' + else: + raise Exception('server_post_decrype data error') + if length > len(self.recv_buf): + break + + if (binascii.crc32(self.recv_buf[:length]) & 0xffffffff) != 0xffffffff: + self.raw_trans = True + self.recv_buf = b'' + if self.decrypt_packet_num == 0: + return b'E' + else: + raise Exception('server_post_decrype data uncorrect CRC32') + + pos = common.ord(self.recv_buf[2]) + 2 + out_buf += self.recv_buf[pos:length - 4] + self.recv_buf = self.recv_buf[length:] + + if out_buf: + self.decrypt_packet_num += 1 + return out_buf + +class verify_deflate(verify_base): + def __init__(self, method): + super(verify_deflate, self).__init__(method) + self.recv_buf = b'' + self.unit_len = 32700 + self.decrypt_packet_num = 0 + self.raw_trans = False + + def pack_data(self, buf): + if len(buf) == 0: + return b'' + data = zlib.compress(buf) + data = struct.pack('>H', len(data)) + data[2:] + return data + + def client_pre_encrypt(self, buf): + ret = b'' + while len(buf) > self.unit_len: + ret += self.pack_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_data(buf) + return ret + + def client_post_decrypt(self, buf): + if self.raw_trans: + return buf + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 2: + length = struct.unpack('>H', self.recv_buf[:2])[0] + if length >= 32768: + self.raw_trans = True + self.recv_buf = b'' + if self.decrypt_packet_num == 0: + return None + else: + raise Exception('server_post_decrype data error') + if length > len(self.recv_buf): + break + + out_buf += zlib.decompress(b'x\x9c' + self.recv_buf[2:length]) + self.recv_buf = self.recv_buf[length:] + + if out_buf: + self.decrypt_packet_num += 1 + return out_buf + + def server_pre_encrypt(self, buf): + ret = b'' + while len(buf) > self.unit_len: + ret += self.pack_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_data(buf) + return ret + + def server_post_decrypt(self, buf): + if self.raw_trans: + return buf + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 2: + length = struct.unpack('>H', self.recv_buf[:2])[0] + if length >= 32768: + self.raw_trans = True + self.recv_buf = b'' + if self.decrypt_packet_num == 0: + return None + else: + raise Exception('server_post_decrype data error') + if length > len(self.recv_buf): + break + + out_buf += zlib.decompress(b'\x78\x9c' + self.recv_buf[2:length]) + self.recv_buf = self.recv_buf[length:] + + if out_buf: + self.decrypt_packet_num += 1 + return out_buf + +class client_queue(object): + def __init__(self, begin_id): + self.front = begin_id + self.back = begin_id + self.alloc = {} + self.enable = True + self.last_update = time.time() + + def update(self): + self.last_update = time.time() + + def is_active(self): + return time.time() - self.last_update < 60 * 3 + + def re_enable(self, connection_id): + self.enable = True + self.alloc = {} + self.front = connection_id + self.back = connection_id + + def insert(self, connection_id): + self.update() + if not self.enable: + logging.warn('auth_simple: not enable') + return False + if connection_id < self.front: + logging.warn('auth_simple: duplicate id') + return False + if not self.is_active(): + self.re_enable(connection_id) + if connection_id > self.front + 0x4000: + logging.warn('auth_simple: wrong id') + return False + if connection_id in self.alloc: + logging.warn('auth_simple: duplicate id 2') + return False + if self.back <= connection_id: + self.back = connection_id + 1 + self.alloc[connection_id] = 1 + while (self.front in self.alloc) or self.front + 0x1000 < self.back: + if self.front in self.alloc: + del self.alloc[self.front] + self.front += 1 + return True + +class obfs_auth_data(object): + def __init__(self): + self.client_id = {} + self.startup_time = int(time.time() - 30) & 0xFFFFFFFF + self.local_client_id = b'' + self.connection_id = 0 + self.max_client = 16 # max active client count + self.max_buffer = max(self.max_client, 256) # max client id buffer size + + def update(self, client_id, connection_id): + if client_id in self.client_id: + self.client_id[client_id].update() + + def insert(self, client_id, connection_id): + if client_id not in self.client_id or not self.client_id[client_id].enable: + active = 0 + for c_id in self.client_id: + if self.client_id[c_id].is_active(): + active += 1 + if active >= self.max_client: + logging.warn('auth_simple: max active clients exceeded') + return False + + if len(self.client_id) < self.max_client: + if client_id not in self.client_id: + self.client_id[client_id] = client_queue(connection_id) + else: + self.client_id[client_id].re_enable(connection_id) + return self.client_id[client_id].insert(connection_id) + keys = self.client_id.keys() + random.shuffle(keys) + for c_id in keys: + if not self.client_id[c_id].is_active() and self.client_id[c_id].enable: + if len(self.client_id) >= self.max_buffer: + del self.client_id[c_id] + else: + self.client_id[c_id].enable = False + if client_id not in self.client_id: + self.client_id[client_id] = client_queue(connection_id) + else: + self.client_id[client_id].re_enable(connection_id) + return self.client_id[client_id].insert(connection_id) + logging.warn('auth_simple: no inactive client [assert]') + return False + else: + return self.client_id[client_id].insert(connection_id) + +class auth_simple(verify_base): + def __init__(self, method): + super(auth_simple, self).__init__(method) + self.recv_buf = b'' + self.unit_len = 8100 + self.decrypt_packet_num = 0 + self.raw_trans = False + self.has_sent_header = False + self.has_recv_header = False + self.client_id = 0 + self.connection_id = 0 + self.max_time_dif = 60 * 5 # time dif (second) setting + + def init_data(self): + return obfs_auth_data() + + def pack_data(self, buf): + if len(buf) == 0: + return b'' + rnd_data = os.urandom(common.ord(os.urandom(1)[0]) % 16) + data = common.chr(len(rnd_data) + 1) + rnd_data + buf + data = struct.pack('>H', len(data) + 6) + data + crc = (0xffffffff - binascii.crc32(data)) & 0xffffffff + data += struct.pack(' 0xFF000000: + self.server_info.data.local_client_id = b'' + if not self.server_info.data.local_client_id: + self.server_info.data.local_client_id = os.urandom(4) + logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),)) + self.server_info.data.connection_id = struct.unpack(' self.unit_len: + ret += self.pack_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_data(buf) + return ret + + def client_post_decrypt(self, buf): + if self.raw_trans: + return buf + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 2: + length = struct.unpack('>H', self.recv_buf[:2])[0] + if length >= 8192: + self.raw_trans = True + self.recv_buf = b'' + if self.decrypt_packet_num == 0: + return None + else: + raise Exception('server_post_decrype data error') + if length > len(self.recv_buf): + break + + if (binascii.crc32(self.recv_buf[:length]) & 0xffffffff) != 0xffffffff: + self.raw_trans = True + self.recv_buf = b'' + if self.decrypt_packet_num == 0: + return None + else: + raise Exception('server_post_decrype data uncorrect CRC32') + + pos = common.ord(self.recv_buf[2]) + 2 + out_buf += self.recv_buf[pos:length - 4] + self.recv_buf = self.recv_buf[length:] + + if out_buf: + self.decrypt_packet_num += 1 + return out_buf + + def server_pre_encrypt(self, buf): + ret = b'' + while len(buf) > self.unit_len: + ret += self.pack_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_data(buf) + return ret + + def server_post_decrypt(self, buf): + if self.raw_trans: + return buf + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 2: + length = struct.unpack('>H', self.recv_buf[:2])[0] + if length >= 8192: + self.raw_trans = True + self.recv_buf = b'' + if self.decrypt_packet_num == 0: + logging.info('auth_simple: over size') + return b'E' + else: + raise Exception('server_post_decrype data error') + if length > len(self.recv_buf): + break + + if (binascii.crc32(self.recv_buf[:length]) & 0xffffffff) != 0xffffffff: + logging.info('auth_simple: crc32 error, data %s' % (binascii.hexlify(self.recv_buf[:length]),)) + self.raw_trans = True + self.recv_buf = b'' + if self.decrypt_packet_num == 0: + return b'E' + else: + raise Exception('server_post_decrype data uncorrect CRC32') + + pos = common.ord(self.recv_buf[2]) + 2 + out_buf += self.recv_buf[pos:length - 4] + if not self.has_recv_header: + if len(out_buf) < 12: + self.raw_trans = True + self.recv_buf = b'' + logging.info('auth_simple: too short') + return b'E' + utc_time = struct.unpack(' self.max_time_dif \ + or common.int32(utc_time - self.server_info.data.startup_time) < 0: + self.raw_trans = True + self.recv_buf = b'' + logging.info('auth_simple: wrong timestamp, time_dif %d, data %s' % (time_dif, binascii.hexlify(out_buf),)) + return b'E' + elif self.server_info.data.insert(client_id, connection_id): + self.has_recv_header = True + out_buf = out_buf[12:] + self.client_id = client_id + self.connection_id = connection_id + else: + self.raw_trans = True + self.recv_buf = b'' + logging.info('auth_simple: auth fail, data %s' % (binascii.hexlify(out_buf),)) + return b'E' + self.recv_buf = self.recv_buf[length:] + + if out_buf: + self.server_info.data.update(self.client_id, self.connection_id) + self.decrypt_packet_num += 1 + return out_buf + diff --git a/shadowsocks/run.sh b/shadowsocks/run.sh new file mode 100644 index 00000000..497ceb84 --- /dev/null +++ b/shadowsocks/run.sh @@ -0,0 +1,5 @@ +#!/bin/bash +cd `dirname $0` +eval $(ps -ef | grep "[0-9] python server\\.py a" | awk '{print "kill "$2}') +nohup python server.py a >> ssserver.log 2>&1 & + diff --git a/shadowsocks/server.py b/shadowsocks/server.py index 429a20a3..4c19474f 100755 --- a/shadowsocks/server.py +++ b/shadowsocks/server.py @@ -23,8 +23,14 @@ import logging import signal -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../')) -from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, asyncdns +if __name__ == '__main__': + import inspect + file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))) + os.chdir(file_path) + sys.path.insert(0, os.path.join(file_path, '../')) + +from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, \ + asyncdns, manager def main(): @@ -48,17 +54,70 @@ def main(): else: config['port_password'][str(server_port)] = config['password'] + if not config.get('dns_ipv6', False): + asyncdns.IPV6_CONNECTION_SUPPORT = False + + if config.get('manager_address', 0): + logging.info('entering manager mode') + manager.run(config) + return + tcp_servers = [] udp_servers = [] dns_resolver = asyncdns.DNSResolver() - for port, password in config['port_password'].items(): + port_password = config['port_password'] + del config['port_password'] + for port, password_obfs in port_password.items(): + protocol = config.get("protocol", 'origin') + obfs_param = config.get("obfs_param", '') + if type(password_obfs) == list: + password = password_obfs[0] + obfs = password_obfs[1] + elif type(password_obfs) == dict: + password = password_obfs.get('password', 'm') + protocol = password_obfs.get('protocol', 'origin') + obfs = password_obfs.get('obfs', 'plain') + obfs_param = password_obfs.get('obfs_param', '') + else: + password = password_obfs + obfs = config["obfs"] a_config = config.copy() - a_config['server_port'] = int(port) - a_config['password'] = password - logging.info("starting server at %s:%d" % - (a_config['server'], int(port))) - tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False)) - udp_servers.append(udprelay.UDPRelay(a_config, dns_resolver, False)) + ipv6_ok = False + logging.info("server start with protocol[%s] password [%s] method [%s] obfs [%s] obfs_param [%s]" % + (protocol, password, a_config['method'], obfs, obfs_param)) + if 'server_ipv6' in a_config: + try: + if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == "[" and a_config['server_ipv6'][-1] == "]": + a_config['server_ipv6'] = a_config['server_ipv6'][1:-1] + a_config['server_port'] = int(port) + a_config['password'] = password + a_config['protocol'] = protocol + a_config['obfs'] = obfs + a_config['obfs_param'] = obfs_param + a_config['server'] = a_config['server_ipv6'] + logging.info("starting server at [%s]:%d" % + (a_config['server'], int(port))) + tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False)) + udp_servers.append(udprelay.UDPRelay(a_config, dns_resolver, False)) + if a_config['server_ipv6'] == b"::": + ipv6_ok = True + except Exception as e: + shell.print_exception(e) + + try: + a_config = config.copy() + a_config['server_port'] = int(port) + a_config['password'] = password + a_config['protocol'] = protocol + a_config['obfs'] = obfs + a_config['obfs_param'] = obfs_param + logging.info("starting server at %s:%d" % + (a_config['server'], int(port))) + tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False)) + udp_servers.append(udprelay.UDPRelay(a_config, dns_resolver, False)) + except Exception as e: + if not ipv6_ok: + shell.print_exception(e) def run_server(): def child_handler(signum, _): diff --git a/shadowsocks/shell.py b/shadowsocks/shell.py index f8ae81f5..38d2432b 100644 --- a/shadowsocks/shell.py +++ b/shadowsocks/shell.py @@ -130,13 +130,13 @@ def get_config(is_local): logging.basicConfig(level=logging.INFO, format='%(levelname)-s: %(message)s') if is_local: - shortopts = 'hd:s:b:p:k:l:m:c:t:vq' + shortopts = 'hd:s:b:p:k:l:m:o:c:t:vq' longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'user=', 'version'] else: - shortopts = 'hd:s:p:k:m:c:t:vq' + shortopts = 'hd:s:p:k:m:o:c:t:vq' longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'workers=', - 'forbidden-ip=', 'user=', 'version'] + 'forbidden-ip=', 'user=', 'manager-address=', 'version'] try: config_path = find_config() optlist, args = getopt.getopt(sys.argv[1:], shortopts, longopts) @@ -148,8 +148,7 @@ def get_config(is_local): logging.info('loading config from %s' % config_path) with open(config_path, 'rb') as f: try: - config = json.loads(f.read().decode('utf8'), - object_hook=_decode_dict) + config = parse_json_in_str(f.read().decode('utf8')) except ValueError as e: logging.error('found an error in config.json: %s', e.message) @@ -169,6 +168,8 @@ def get_config(is_local): config['server'] = to_str(value) elif key == '-m': config['method'] = to_str(value) + elif key == '-o': + config['obfs'] = to_str(value) elif key == '-b': config['local_address'] = to_str(value) elif key == '-v': @@ -181,6 +182,8 @@ def get_config(is_local): config['fast_open'] = True elif key == '--workers': config['workers'] = int(value) + elif key == '--manager-address': + config['manager_address'] = value elif key == '--user': config['user'] = to_str(value) elif key == '--forbidden-ip': @@ -215,6 +218,9 @@ def get_config(is_local): config['password'] = to_bytes(config.get('password', b'')) config['method'] = to_str(config.get('method', 'aes-256-cfb')) + config['protocol'] = to_str(config.get('protocol', 'origin')) + config['obfs'] = to_str(config.get('obfs', 'plain')) + config['obfs_param'] = to_str(config.get('obfs_param', '')) config['port_password'] = config.get('port_password', None) config['timeout'] = int(config.get('timeout', 300)) config['fast_open'] = config.get('fast_open', False) @@ -284,6 +290,7 @@ def print_local_help(): -l LOCAL_PORT local port, default: 1080 -k PASSWORD password -m METHOD encryption method, default: aes-256-cfb + -o OBFS obfsplugin, default: http_simple -t TIMEOUT timeout in seconds, default: 300 --fast-open use TCP_FASTOPEN, requires Linux 3.7+ @@ -313,10 +320,12 @@ def print_server_help(): -p SERVER_PORT server port, default: 8388 -k PASSWORD password -m METHOD encryption method, default: aes-256-cfb + -o OBFS obfsplugin, default: http_simple -t TIMEOUT timeout in seconds, default: 300 --fast-open use TCP_FASTOPEN, requires Linux 3.7+ --workers WORKERS number of workers, available on Unix/Linux --forbidden-ip IPLIST comma seperated IP list forbidden to connect + --manager-address ADDR optional server manager UDP address, see wiki General options: -h, --help show this help message and exit @@ -356,3 +365,8 @@ def _decode_dict(data): value = _decode_dict(value) rv[key] = value return rv + + +def parse_json_in_str(data): + # parse json and convert everything from unicode to str + return json.loads(data, object_hook=_decode_dict) diff --git a/shadowsocks/stop.sh b/shadowsocks/stop.sh new file mode 100644 index 00000000..af1fbf92 --- /dev/null +++ b/shadowsocks/stop.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +eval $(ps -ef | grep "[0-9] python server\\.py a" | awk '{print "kill "$2}') diff --git a/shadowsocks/tail.sh b/shadowsocks/tail.sh new file mode 100644 index 00000000..aa371393 --- /dev/null +++ b/shadowsocks/tail.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +tail -f ssserver.log diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 4834883a..395ecbf9 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -23,18 +23,19 @@ import errno import struct import logging +import binascii import traceback import random -from shadowsocks import encrypt, eventloop, shell, common -from shadowsocks.common import parse_header +from shadowsocks import encrypt, obfs, eventloop, shell, common +from shadowsocks.common import pre_parse_header, parse_header + +# set it 'True' if run as a local client and connect to a server which support new protocol +CLIENT_NEW_PROTOCOL = False #deprecated # we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time TIMEOUTS_CLEAN_SIZE = 512 -# we check timeouts every TIMEOUT_PRECISION seconds -TIMEOUT_PRECISION = 4 - MSG_FASTOPEN = 0x20000000 # SOCKS command definition @@ -101,6 +102,8 @@ def __init__(self, server, fd_to_handlers, loop, local_sock, config, self._loop = loop self._local_sock = local_sock self._remote_sock = None + self._remote_sock_v6 = None + self._remote_udp = False self._config = config self._dns_resolver = dns_resolver @@ -110,9 +113,27 @@ def __init__(self, server, fd_to_handlers, loop, local_sock, config, self._stage = STAGE_INIT self._encryptor = encrypt.Encryptor(config['password'], config['method']) + self._encrypt_correct = True + self._obfs = obfs.obfs(config['obfs']) + server_info = obfs.server_info(server.obfs_data) + server_info.host = config['server'] + server_info.port = server._listen_port + server_info.tcp_mss = 1440 + server_info.param = config['obfs_param'] + self._obfs.set_server_info(server_info) + + self._protocol = obfs.obfs(config['protocol']) + server_info = obfs.server_info(server.protocol_data) + server_info.host = config['server'] + server_info.port = server._listen_port + server_info.tcp_mss = 1440 + server_info.param = '' + self._protocol.set_server_info(server_info) + self._fastopen_connected = False self._data_to_write_to_local = [] self._data_to_write_to_remote = [] + self._udp_data_send_buffer = b'' self._upstream_status = WAIT_STATUS_READING self._downstream_status = WAIT_STATUS_INIT self._client_address = local_sock.getpeername()[:2] @@ -126,9 +147,11 @@ def __init__(self, server, fd_to_handlers, loop, local_sock, config, fd_to_handlers[local_sock.fileno()] = self local_sock.setblocking(False) local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR) + loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR, + self._server) self.last_activity = 0 self._update_activity() + self._server.add_connection(1) def __hash__(self): # default __hash__ is id / 16 @@ -149,10 +172,10 @@ def _get_a_server(self): logging.debug('chosen server: %s:%d', server, server_port) return server, server_port - def _update_activity(self): + def _update_activity(self, data_len=0): # tell the TCP Relay we have activities recently # else it will think we are inactive and timed out - self._server.update_activity(self) + self._server.update_activity(self, data_len) def _update_stream(self, stream, status): # update a stream to a new waiting status @@ -183,26 +206,90 @@ def _update_stream(self, stream, status): if self._upstream_status & WAIT_STATUS_WRITING: event |= eventloop.POLL_OUT self._loop.modify(self._remote_sock, event) + if self._remote_sock_v6: + self._loop.modify(self._remote_sock_v6, event) def _write_to_sock(self, data, sock): # write data to sock # if only some of the data are written, put remaining in the buffer # and update the stream to wait for writing - if not data or not sock: + if not sock: return False + #logging.debug("_write_to_sock %s %s %s" % (self._remote_sock, sock, self._remote_udp)) uncomplete = False - try: - l = len(data) - s = sock.send(data) - if s < l: - data = data[s:] - uncomplete = True - except (OSError, IOError) as e: - error_no = eventloop.errno_from_exception(e) - if error_no in (errno.EAGAIN, errno.EINPROGRESS, - errno.EWOULDBLOCK): - uncomplete = True - else: + if self._remote_udp and sock == self._remote_sock: + try: + self._udp_data_send_buffer += data + #logging.info('UDP over TCP sendto %d %s' % (len(data), binascii.hexlify(data))) + while len(self._udp_data_send_buffer) > 6: + length = struct.unpack('>H', self._udp_data_send_buffer[:2])[0] + + if length > len(self._udp_data_send_buffer): + break + + data = self._udp_data_send_buffer[:length] + self._udp_data_send_buffer = self._udp_data_send_buffer[length:] + + frag = common.ord(data[2]) + if frag != 0: + logging.warn('drop a message since frag is %d' % (frag,)) + continue + else: + data = data[3:] + header_result = parse_header(data) + if header_result is None: + continue + connecttype, dest_addr, dest_port, header_length = header_result + addrs = socket.getaddrinfo(dest_addr, dest_port, 0, + socket.SOCK_DGRAM, socket.SOL_UDP) + #logging.info('UDP over TCP sendto %s:%d %d bytes from %s:%d' % (dest_addr, dest_port, len(data), self._client_address[0], self._client_address[1])) + if addrs: + af, socktype, proto, canonname, server_addr = addrs[0] + data = data[header_length:] + if af == socket.AF_INET6: + self._remote_sock_v6.sendto(data, (server_addr[0], dest_port)) + else: + sock.sendto(data, (server_addr[0], dest_port)) + + except Exception as e: + #trace = traceback.format_exc() + #logging.error(trace) + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + uncomplete = True + else: + shell.print_exception(e) + self.destroy() + return False + return True + else: + try: + if self._is_local: + pass + else: + if sock == self._local_sock and self._encrypt_correct: + obfs_encode = self._obfs.server_encode(data) + data = obfs_encode + if data: + l = len(data) + s = sock.send(data) + if s < l: + data = data[s:] + uncomplete = True + else: + return + except (OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + uncomplete = True + else: + #traceback.print_exc() + shell.print_exception(e) + self.destroy() + return False + except Exception as e: shell.print_exception(e) self.destroy() return False @@ -224,10 +311,38 @@ def _write_to_sock(self, data, sock): logging.error('write_all_to_sock:unknown socket') return True + def _get_redirect_host(self, client_address, ogn_data): + # test + host_list = [(b"www.bing.com", 80), (b"www.microsoft.com", 80), (b"cloudfront.com", 80), (b"cloudflare.com", 80), (b"1.2.3.4", 1000), (b"0.0.0.0", 0)] + hash_code = binascii.crc32(ogn_data) + addrs = socket.getaddrinfo(client_address[0], client_address[1], 0, socket.SOCK_STREAM, socket.SOL_TCP) + af, socktype, proto, canonname, sa = addrs[0] + address_bytes = common.inet_pton(af, sa[0]) + if len(address_bytes) == 16: + addr = struct.unpack('>Q', address_bytes[8:])[0] + if len(address_bytes) == 4: + addr = struct.unpack('>I', address_bytes)[0] + else: + addr = 0 + return host_list[((hash_code & 0xffffffff) + addr + 3) % len(host_list)] + + def _handel_protocol_error(self, client_address, ogn_data): + #raise Exception('can not parse header') + logging.warn("Protocol ERROR, TCP ogn data %s from %s:%d" % (binascii.hexlify(ogn_data), client_address[0], client_address[1])) + self._encrypt_correct = False + #create redirect or disconnect by hash code + host, port = self._get_redirect_host(client_address, ogn_data) + data = b"\x03" + common.chr(len(host)) + host + struct.pack('>H', port) + logging.warn("TCP data redir %s:%d %s" % (host, port, binascii.hexlify(data))) + return data + ogn_data + def _handle_stage_connecting(self, data): if self._is_local: + data = self._protocol.client_pre_encrypt(data) data = self._encryptor.encrypt(data) - self._data_to_write_to_remote.append(data) + data = self._obfs.client_encode(data) + if data: + self._data_to_write_to_remote.append(data) if self._is_local and not self._fastopen_connected and \ self._config['fast_open']: # for sslocal and fastopen, we basically wait for data and use @@ -238,7 +353,7 @@ def _handle_stage_connecting(self, data): remote_sock = \ self._create_remote_socket(self._chosen_server[0], self._chosen_server[1]) - self._loop.add(remote_sock, eventloop.POLL_ERR) + self._loop.add(remote_sock, eventloop.POLL_ERR, self._server) data = b''.join(self._data_to_write_to_remote) l = len(data) s = remote_sock.sendto(data, MSG_FASTOPEN, self._chosen_server) @@ -262,7 +377,7 @@ def _handle_stage_connecting(self, data): traceback.print_exc() self.destroy() - def _handle_stage_addr(self, data): + def _handle_stage_addr(self, ogn_data, data): try: if self._is_local: cmd = common.ord(data[1]) @@ -288,14 +403,25 @@ def _handle_stage_addr(self, data): logging.error('unknown command %d', cmd) self.destroy() return - header_result = parse_header(data) - if header_result is None: - raise Exception('can not parse header') - addrtype, remote_addr, remote_port, header_length = header_result - logging.info('connecting %s:%d from %s:%d' % - (common.to_str(remote_addr), remote_port, - self._client_address[0], self._client_address[1])) + + before_parse_data = data + if self._is_local: + header_result = parse_header(data) + else: + data = pre_parse_header(data) + if data is None: + data = self._handel_protocol_error(self._client_address, ogn_data) + header_result = parse_header(data) + if header_result is None: + data = self._handel_protocol_error(self._client_address, ogn_data) + header_result = parse_header(data) + connecttype, remote_addr, remote_port, header_length = header_result + logging.info('%s connecting %s:%d from %s:%d' % + ((connecttype == 0) and 'TCP' or 'UDP', + common.to_str(remote_addr), remote_port, + self._client_address[0], self._client_address[1])) self._remote_address = (common.to_str(remote_addr), remote_port) + self._remote_udp = (connecttype != 0) # pause reading self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) self._stage = STAGE_DNS @@ -304,8 +430,17 @@ def _handle_stage_addr(self, data): self._write_to_sock((b'\x05\x00\x00\x01' b'\x00\x00\x00\x00\x10\x10'), self._local_sock) + if CLIENT_NEW_PROTOCOL: + rnd_len = random.randint(1, 32) + total_len = 7 + rnd_len + len(data) + data = b'\x88' + struct.pack('>H', total_len) + chr(rnd_len) + (b' ' * (rnd_len - 1)) + data + crc = (0xffffffff - binascii.crc32(data)) & 0xffffffff + data += struct.pack('H', addr[1]) + try: + ip = socket.inet_aton(addr[0]) + data = b'\x00\x01' + ip + port + data + except Exception as e: + ip = socket.inet_pton(socket.AF_INET6, addr[0]) + data = b'\x00\x04' + ip + port + data + data = struct.pack('>H', len(data) + 2) + data + #logging.info('UDP over TCP recvfrom %s:%d %d bytes to %s:%d' % (addr[0], addr[1], len(data), self._client_address[0], self._client_address[1])) + else: + data = self._remote_sock.recv(BUF_SIZE) except (OSError, IOError) as e: if eventloop.errno_from_exception(e) in \ - (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK): + (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): #errno.WSAEWOULDBLOCK return if not data: self.destroy() return + self._server.server_transfer_dl += len(data) + self._update_activity(len(data)) if self._is_local: - data = self._encryptor.decrypt(data) + obfs_decode = self._obfs.client_decode(data) + if obfs_decode[1]: + send_back = self._obfs.client_encode(b'') + self._write_to_sock(send_back, self._remote_sock) + data = self._encryptor.decrypt(obfs_decode[0]) + data = self._protocol.client_post_decrypt(data) else: - data = self._encryptor.encrypt(data) + if self._encrypt_correct: + data = self._protocol.server_pre_encrypt(data) + data = self._encryptor.encrypt(data) try: self._write_to_sock(data, self._local_sock) except Exception as e: @@ -486,13 +693,13 @@ def handle_event(self, sock, event): logging.debug('ignore handle_event: destroyed') return # order is important - if sock == self._remote_sock: + if sock == self._remote_sock or sock == self._remote_sock_v6: if event & eventloop.POLL_ERR: self._on_remote_error() if self._stage == STAGE_DESTROYED: return if event & (eventloop.POLL_IN | eventloop.POLL_HUP): - self._on_remote_read() + self._on_remote_read(sock == self._remote_sock) if self._stage == STAGE_DESTROYED: return if event & eventloop.POLL_OUT: @@ -535,29 +742,51 @@ def destroy(self): logging.debug('destroy') if self._remote_sock: logging.debug('destroying remote') - self._loop.remove(self._remote_sock) + try: + self._loop.remove(self._remote_sock) + except Exception as e: + pass del self._fd_to_handlers[self._remote_sock.fileno()] self._remote_sock.close() self._remote_sock = None + if self._remote_sock_v6: + logging.debug('destroying remote') + try: + self._loop.remove(self._remote_sock_v6) + except Exception as e: + pass + del self._fd_to_handlers[self._remote_sock_v6.fileno()] + self._remote_sock_v6.close() + self._remote_sock_v6 = None if self._local_sock: logging.debug('destroying local') self._loop.remove(self._local_sock) del self._fd_to_handlers[self._local_sock.fileno()] self._local_sock.close() self._local_sock = None + if self._obfs: + self._obfs.dispose() + self._obfs = None + if self._protocol: + self._protocol.dispose() + self._protocol = None self._dns_resolver.remove_callback(self._handle_dns_resolved) self._server.remove_handler(self) - + self._server.add_connection(-1) class TCPRelay(object): - def __init__(self, config, dns_resolver, is_local): + def __init__(self, config, dns_resolver, is_local, stat_callback=None): self._config = config self._is_local = is_local self._dns_resolver = dns_resolver self._closed = False self._eventloop = None self._fd_to_handlers = {} - self._last_time = time.time() + self.server_transfer_ul = 0 + self.server_transfer_dl = 0 + self.server_connections = 0 + self.protocol_data = obfs.obfs(config['protocol']).init_data() + self.obfs_data = obfs.obfs(config['obfs']).init_data() self._timeout = config['timeout'] self._timeouts = [] # a list for all the handlers @@ -591,6 +820,7 @@ def __init__(self, config, dns_resolver, is_local): self._config['fast_open'] = False server_socket.listen(1024) self._server_socket = server_socket + self._stat_callback = stat_callback def add_to_loop(self, loop): if self._eventloop: @@ -598,10 +828,9 @@ def add_to_loop(self, loop): if self._closed: raise Exception('already closed') self._eventloop = loop - loop.add_handler(self._handle_events) - self._eventloop.add(self._server_socket, - eventloop.POLL_IN | eventloop.POLL_ERR) + eventloop.POLL_IN | eventloop.POLL_ERR, self) + self._eventloop.add_periodic(self.handle_periodic) def remove_handler(self, handler): index = self._handler_to_timeouts.get(hash(handler), -1) @@ -610,10 +839,17 @@ def remove_handler(self, handler): self._timeouts[index] = None del self._handler_to_timeouts[hash(handler)] - def update_activity(self, handler): + def add_connection(self, val): + self.server_connections += val + logging.debug('server port %5d connections = %d' % (self._listen_port, self.server_connections,)) + + def update_activity(self, handler, data_len): + if data_len and self._stat_callback: + self._stat_callback(self._listen_port, data_len) + # set handler to active now = int(time.time()) - if now - handler.last_activity < TIMEOUT_PRECISION: + if now - handler.last_activity < eventloop.TIMEOUT_PRECISION: # thus we can lower timeout modification frequency return handler.last_activity = now @@ -659,53 +895,57 @@ def _sweep_timeout(self): pos = 0 self._timeout_offset = pos - def _handle_events(self, events): + def handle_event(self, sock, fd, event): # handle events and dispatch to handlers - for sock, fd, event in events: + if sock: + logging.log(shell.VERBOSE_LEVEL, 'fd %d %s', fd, + eventloop.EVENT_NAMES.get(event, event)) + if sock == self._server_socket: + if event & eventloop.POLL_ERR: + # TODO + raise Exception('server_socket error') + try: + logging.debug('accept') + conn = self._server_socket.accept() + TCPRelayHandler(self, self._fd_to_handlers, + self._eventloop, conn[0], self._config, + self._dns_resolver, self._is_local) + except (OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + return + else: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + else: if sock: - logging.log(shell.VERBOSE_LEVEL, 'fd %d %s', fd, - eventloop.EVENT_NAMES.get(event, event)) - if sock == self._server_socket: - if event & eventloop.POLL_ERR: - # TODO - raise Exception('server_socket error') - try: - logging.debug('accept') - conn = self._server_socket.accept() - TCPRelayHandler(self, self._fd_to_handlers, - self._eventloop, conn[0], self._config, - self._dns_resolver, self._is_local) - except (OSError, IOError) as e: - error_no = eventloop.errno_from_exception(e) - if error_no in (errno.EAGAIN, errno.EINPROGRESS, - errno.EWOULDBLOCK): - continue - else: - shell.print_exception(e) - if self._config['verbose']: - traceback.print_exc() + handler = self._fd_to_handlers.get(fd, None) + if handler: + handler.handle_event(sock, event) else: - if sock: - handler = self._fd_to_handlers.get(fd, None) - if handler: - handler.handle_event(sock, event) - else: - logging.warn('poll removed fd') + logging.warn('poll removed fd') - now = time.time() - if now - self._last_time > TIMEOUT_PRECISION: - self._sweep_timeout() - self._last_time = now + def handle_periodic(self): if self._closed: if self._server_socket: self._eventloop.remove(self._server_socket) self._server_socket.close() self._server_socket = None - logging.info('closed listen port %d', self._listen_port) + logging.info('closed TCP port %d', self._listen_port) if not self._fd_to_handlers: - self._eventloop.remove_handler(self._handle_events) + logging.info('stopping') + self._eventloop.stop() + self._sweep_timeout() def close(self, next_tick=False): + logging.debug('TCP close') self._closed = True if not next_tick: + if self._eventloop: + self._eventloop.remove_periodic(self.handle_periodic) + self._eventloop.remove(self._server_socket) self._server_socket.close() + for handler in list(self._fd_to_handlers.values()): + handler.destroy() diff --git a/shadowsocks/udprelay.py b/shadowsocks/udprelay.py index 98bfaaa7..5519466b 100644 --- a/shadowsocks/udprelay.py +++ b/shadowsocks/udprelay.py @@ -68,20 +68,802 @@ import struct import errno import random +import binascii +import traceback from shadowsocks import encrypt, eventloop, lru_cache, common, shell -from shadowsocks.common import parse_header, pack_addr +from shadowsocks.common import pre_parse_header, parse_header, pack_addr +# we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time +TIMEOUTS_CLEAN_SIZE = 512 + +# for each handler, we have 2 stream directions: +# upstream: from client to server direction +# read local and write to remote +# downstream: from server to client direction +# read remote and write to local + +STREAM_UP = 0 +STREAM_DOWN = 1 + +# for each stream, it's waiting for reading, or writing, or both +WAIT_STATUS_INIT = 0 +WAIT_STATUS_READING = 1 +WAIT_STATUS_WRITING = 2 +WAIT_STATUS_READWRITING = WAIT_STATUS_READING | WAIT_STATUS_WRITING BUF_SIZE = 65536 +DOUBLE_SEND_BEG_IDS = 16 +POST_MTU_MIN = 500 +POST_MTU_MAX = 1400 +SENDING_WINDOW_SIZE = 8192 + +STAGE_INIT = 0 +STAGE_RSP_ID = 1 +STAGE_DNS = 2 +STAGE_CONNECTING = 3 +STAGE_STREAM = 4 +STAGE_DESTROYED = -1 + +CMD_CONNECT = 0 +CMD_RSP_CONNECT = 1 +CMD_CONNECT_REMOTE = 2 +CMD_RSP_CONNECT_REMOTE = 3 +CMD_POST = 4 +CMD_SYN_STATUS = 5 +CMD_POST_64 = 6 +CMD_SYN_STATUS_64 = 7 +CMD_DISCONNECT = 8 + +CMD_VER_STR = b"\x08" + +RSP_STATE_EMPTY = b"" +RSP_STATE_REJECT = b"\x00" +RSP_STATE_CONNECTED = b"\x01" +RSP_STATE_CONNECTEDREMOTE = b"\x02" +RSP_STATE_ERROR = b"\x03" +RSP_STATE_DISCONNECT = b"\x04" +RSP_STATE_REDIRECT = b"\x05" + +class UDPLocalAddress(object): + def __init__(self, addr): + self.addr = addr + self.last_activity = time.time() + + def is_timeout(self): + return time.time() - self.last_activity > 30 + +class PacketInfo(object): + def __init__(self, data): + self.data = data + self.time = time.time() + +class SendingQueue(object): + def __init__(self): + self.queue = {} + self.begin_id = 0 + self.end_id = 1 + self.interval = 0.5 + + def append(self, data): + self.queue[self.end_id] = PacketInfo(data) + self.end_id += 1 + return self.end_id - 1 + + def empty(self): + return self.begin_id + 1 == self.end_id + + def size(self): + return self.end_id - self.begin_id - 1 + + def get_begin_id(self): + return self.begin_id + + def get_end_id(self): + return self.end_id + + def get_data_list(self, pack_id_base, pack_id_list): + ret_list = [] + curtime = time.time() + for pack_id in pack_id_list: + offset = pack_id_base + pack_id + if offset <= self.begin_id or self.end_id <= offset: + continue + ret_data = self.queue[offset] + if curtime - ret_data.time > self.interval: + ret_data.time = curtime + ret_list.append( (offset, ret_data.data) ) + return ret_list + + def set_finish(self, begin_id, done_list): + while self.begin_id < begin_id: + self.begin_id += 1 + del self.queue[self.begin_id] + +class RecvQueue(object): + def __init__(self): + self.queue = {} + self.miss_queue = set() + self.begin_id = 0 + self.end_id = 1 + + def empty(self): + return self.begin_id + 1 == self.end_id + + def insert(self, pack_id, data): + if (pack_id not in self.queue) and pack_id > self.begin_id: + self.queue[pack_id] = PacketInfo(data) + if self.end_id == pack_id: + self.end_id = pack_id + 1 + elif self.end_id < pack_id: + eid = self.end_id + while eid < pack_id: + self.miss_queue.add(eid) + eid += 1 + self.end_id = pack_id + 1 + else: + self.miss_queue.remove(pack_id) + + def set_end(self, end_id): + if end_id > self.end_id: + eid = self.end_id + while eid < end_id: + self.miss_queue.add(eid) + eid += 1 + self.end_id = end_id + + def get_begin_id(self): + return self.begin_id + + def has_data(self): + return (self.begin_id + 1) in self.queue + + def get_data(self): + if (self.begin_id + 1) in self.queue: + self.begin_id += 1 + pack_id = self.begin_id + ret_data = self.queue[pack_id] + del self.queue[pack_id] + return (pack_id, ret_data.data) + + def get_missing_id(self, begin_id): + missing = [] + if begin_id == 0: + begin_id = self.begin_id + for i in self.miss_queue: + if i - begin_id > 32768: + break + missing.append(i - begin_id) + return (begin_id, missing) + +class AddressMap(object): + def __init__(self): + self._queue = [] + self._addr_map = {} + + def add(self, addr): + if addr in self._addr_map: + self._addr_map[addr] = UDPLocalAddress(addr) + else: + self._addr_map[addr] = UDPLocalAddress(addr) + self._queue.append(addr) + + def keys(self): + return self._queue + + def get(self): + if self._queue: + while True: + if len(self._queue) == 1: + return self._queue[0] + index = random.randint(0, len(self._queue) - 1) + addr = self._queue[index] + if self._addr_map[addr].is_timeout(): + self._queue[index] = self._queue[len(self._queue) - 1] + del self._queue[len(self._queue) - 1] + del self._addr_map[addr] + else: + break + return addr + else: + return None + +class TCPRelayHandler(object): + def __init__(self, server, reqid_to_handlers, fd_to_handlers, loop, + local_sock, local_id, client_param, config, + dns_resolver, is_local): + self._server = server + self._reqid_to_handlers = reqid_to_handlers + self._fd_to_handlers = fd_to_handlers + self._loop = loop + self._local_sock = local_sock + self._remote_sock = None + self._remote_udp = False + self._config = config + self._dns_resolver = dns_resolver + self._local_id = local_id + + self._is_local = is_local + self._stage = STAGE_INIT + self._password = config['password'] + self._method = config['method'] + self._fastopen_connected = False + self._data_to_write_to_local = [] + self._data_to_write_to_remote = [] + self._upstream_status = WAIT_STATUS_READING + self._downstream_status = WAIT_STATUS_INIT + self._request_id = 0 + self._client_address = AddressMap() + self._remote_address = None + self._sendingqueue = SendingQueue() + self._recvqueue = RecvQueue() + if 'forbidden_ip' in config: + self._forbidden_iplist = config['forbidden_ip'] + else: + self._forbidden_iplist = None + #fd_to_handlers[local_sock.fileno()] = self + #local_sock.setblocking(False) + #loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR) + self.last_activity = 0 + self._update_activity() + self._random_mtu_size = [random.randint(POST_MTU_MIN, POST_MTU_MAX) for i in range(1024)] + self._random_mtu_index = 0 + + self._rand_data = b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10" * 4 + + def __hash__(self): + # default __hash__ is id / 16 + # we want to eliminate collisions + return id(self) + + @property + def remote_address(self): + return self._remote_address + + def add_local_address(self, addr): + self._client_address.add(addr) + + def get_local_address(self): + return self._client_address.get() + + def _update_activity(self): + # tell the TCP Relay we have activities recently + # else it will think we are inactive and timed out + self._server.update_activity(self) + + def _update_stream(self, stream, status): + # update a stream to a new waiting status + + # check if status is changed + # only update if dirty + dirty = False + if stream == STREAM_DOWN: + if self._downstream_status != status: + self._downstream_status = status + dirty = True + elif stream == STREAM_UP: + if self._upstream_status != status: + self._upstream_status = status + dirty = True + if dirty: + ''' + if self._local_sock: + event = eventloop.POLL_ERR + if self._downstream_status & WAIT_STATUS_WRITING: + event |= eventloop.POLL_OUT + if self._upstream_status & WAIT_STATUS_READING: + event |= eventloop.POLL_IN + self._loop.modify(self._local_sock, event) + ''' + if self._remote_sock: + event = eventloop.POLL_ERR + if self._downstream_status & WAIT_STATUS_READING: + event |= eventloop.POLL_IN + if self._upstream_status & WAIT_STATUS_WRITING: + event |= eventloop.POLL_OUT + self._loop.modify(self._remote_sock, event) + + def _write_to_sock(self, data, sock, addr = None): + # write data to sock + # if only some of the data are written, put remaining in the buffer + # and update the stream to wait for writing + if not data or not sock: + return False + + uncomplete = False + retry = 0 + if sock == self._local_sock: + data = encrypt.encrypt_all(self._password, self._method, 1, data) + if addr is None: + return False + try: + self._server.write_to_server_socket(data, addr) + except (OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + uncomplete = True + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + pass + else: + #traceback.print_exc() + shell.print_exception(e) + self.destroy() + return False + else: + try: + l = len(data) + s = sock.send(data) + if s < l: + data = data[s:] + uncomplete = True + except (OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + uncomplete = True + else: + #logging.error(traceback.extract_stack()) + #traceback.print_exc() + shell.print_exception(e) + self.destroy() + return False + if uncomplete: + if sock == self._local_sock: + self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) + elif sock == self._remote_sock: + self._data_to_write_to_remote.append(data) + self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) + else: + logging.error('write_all_to_sock:unknown socket') + else: + if sock == self._local_sock: + if self._sendingqueue.size() > SENDING_WINDOW_SIZE: + self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) + else: + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + elif sock == self._remote_sock: + self._update_stream(STREAM_UP, WAIT_STATUS_READING) + else: + logging.error('write_all_to_sock:unknown socket') + return True + + def _create_remote_socket(self, ip, port): + addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM, socket.SOL_TCP) + if len(addrs) == 0: + raise Exception("getaddrinfo failed for %s:%d" % (ip, port)) + af, socktype, proto, canonname, sa = addrs[0] + if self._forbidden_iplist: + if common.to_str(sa[0]) in self._forbidden_iplist: + raise Exception('IP %s is in forbidden list, reject' % + common.to_str(sa[0])) + remote_sock = socket.socket(af, socktype, proto) + self._remote_sock = remote_sock + + self._fd_to_handlers[remote_sock.fileno()] = self + + remote_sock.setblocking(False) + remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) + return remote_sock + + def _handle_dns_resolved(self, result, error): + if error: + self._log_error(error) + self.destroy() + return + if result: + ip = result[1] + if ip: + + try: + self._stage = STAGE_CONNECTING + remote_addr = ip + remote_port = self._remote_address[1] + logging.info("connect to %s : %d" % (remote_addr, remote_port)) + + remote_sock = self._create_remote_socket(remote_addr, + remote_port) + try: + remote_sock.connect((remote_addr, remote_port)) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) in (errno.EINPROGRESS, + errno.EWOULDBLOCK): + pass # always goto here + else: + raise e + + self._loop.add(remote_sock, + eventloop.POLL_ERR | eventloop.POLL_OUT, + self._server) + self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + self._stage = STAGE_STREAM + + addr = self.get_local_address() + + for i in range(2): + rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, RSP_STATE_CONNECTEDREMOTE) + self._write_to_sock(rsp_data, self._local_sock, addr) + + return + except Exception as e: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + self.destroy() + + def _on_local_read(self): + # handle all local read events and dispatch them to methods for + # each stage + self._update_activity() + if not self._local_sock: + return + data = None + try: + data = self._local_sock.recv(BUF_SIZE) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) in \ + (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK): + return + if not data: + self.destroy() + return + if not data: + return + self._server.server_transfer_ul += len(data) + #TODO ============================================================ + if self._stage == STAGE_STREAM: + self._write_to_sock(data, self._remote_sock) + return + + def _on_remote_read(self): + # handle all remote read events + self._update_activity() + data = None + try: + data = self._remote_sock.recv(BUF_SIZE) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) in \ + (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): #errno.WSAEWOULDBLOCK + return + if not data: + self.destroy() + return + self._server.server_transfer_dl += len(data) + try: + recv_data = data + beg_pos = 0 + max_len = len(recv_data) + while beg_pos < max_len: + if beg_pos + POST_MTU_MAX >= max_len: + split_pos = max_len + else: + split_pos = beg_pos + self._random_mtu_size[self._random_mtu_index] + self._random_mtu_index = (self._random_mtu_index + 1) & 0x3ff + #split_pos = beg_pos + random.randint(POST_MTU_MIN, POST_MTU_MAX) + data = recv_data[beg_pos:split_pos] + beg_pos = split_pos + + pack_id = self._sendingqueue.append(data) + post_data = self._pack_post_data(CMD_POST, pack_id, data) + addr = self.get_local_address() + self._write_to_sock(post_data, self._local_sock, addr) + if pack_id <= DOUBLE_SEND_BEG_IDS: + post_data = self._pack_post_data(CMD_POST, pack_id, data) + self._write_to_sock(post_data, self._local_sock, addr) + + except Exception as e: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + # TODO use logging when debug completed + self.destroy() + def _on_local_write(self): + # handle local writable event + if self._data_to_write_to_local: + data = b''.join(self._data_to_write_to_local) + self._data_to_write_to_local = [] + self._write_to_sock(data, self._local_sock) + else: + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + + def _on_remote_write(self): + # handle remote writable event + self._stage = STAGE_STREAM + if self._data_to_write_to_remote: + data = b''.join(self._data_to_write_to_remote) + self._data_to_write_to_remote = [] + self._write_to_sock(data, self._remote_sock) + else: + self._update_stream(STREAM_UP, WAIT_STATUS_READING) + + def _on_local_error(self): + logging.debug('got local error') + if self._local_sock: + logging.error(eventloop.get_sock_error(self._local_sock)) + self.destroy() + + def _on_remote_error(self): + logging.debug('got remote error') + if self._remote_sock: + logging.error(eventloop.get_sock_error(self._remote_sock)) + self.destroy() + + def _pack_rsp_data(self, cmd, data): + reqid_str = struct.pack(">H", self._request_id) + return b''.join([CMD_VER_STR, common.chr(cmd), reqid_str, data, self._rand_data[:random.randint(0, len(self._rand_data))], reqid_str]) -def client_key(a, b, c, d): - return '%s:%s:%s:%s' % (a, b, c, d) + def _pack_rnd_data(self, data): + length = random.randint(0, len(self._rand_data)) + if length == 0: + return data + elif length == 1: + return b"\x81" + data + elif length < 256: + return b"\x80" + common.chr(length) + self._rand_data[:length - 2] + data + else: + return b"\x82" + struct.pack(">H", length) + self._rand_data[:length - 3] + data + + def _pack_post_data(self, cmd, pack_id, data): + reqid_str = struct.pack(">H", self._request_id) + recv_id = self._recvqueue.get_begin_id() + rsp_data = b''.join([CMD_VER_STR, common.chr(cmd), reqid_str, struct.pack(">I", recv_id), struct.pack(">I", pack_id), data, reqid_str]) + return rsp_data + + def _pack_post_data_64(self, cmd, send_id, pack_id, data): + reqid_str = struct.pack(">H", self._request_id) + recv_id = self._recvqueue.get_begin_id() + rsp_data = b''.join([CMD_VER_STR, common.chr(cmd), reqid_str, struct.pack(">Q", recv_id), struct.pack(">Q", pack_id), data, reqid_str]) + return rsp_data + + def sweep_timeout(self): + logging.info("sweep_timeout") + if self._stage == STAGE_STREAM: + pack_id, missing = self._recvqueue.get_missing_id(0) + logging.info("sweep_timeout %s %s" % (pack_id, missing)) + data = b'' + for pid in missing: + data += struct.pack(">H", pid) + rsp_data = self._pack_post_data(CMD_SYN_STATUS, pack_id, data) + addr = self.get_local_address() + self._write_to_sock(rsp_data, self._local_sock, addr) + + def handle_stream_sync_status(self, addr, cmd, request_id, pack_id, max_send_id, data): + missing_list = [] + while len(data) >= 2: + pid = struct.unpack(">H", data[0:2])[0] + data = data[2:] + missing_list.append(pid) + done_list = [] + self._recvqueue.set_end(max_send_id) + self._sendingqueue.set_finish(pack_id, done_list) + + if self._stage == STAGE_DESTROYED and self._sendingqueue.empty(): + self.destroy_local() + return + + # post CMD_SYN_STATUS + send_id = self._sendingqueue.get_end_id() + post_pack_id, missing = self._recvqueue.get_missing_id(0) + pack_ids_data = b'' + for pid in missing: + pack_ids_data += struct.pack(">H", pid) + + rsp_data = self._pack_rnd_data(self._pack_post_data(CMD_SYN_STATUS, send_id, pack_ids_data)) + self._write_to_sock(rsp_data, self._local_sock, addr) + + send_list = self._sendingqueue.get_data_list(pack_id, missing_list) + for post_pack_id, post_data in send_list: + rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data) + self._write_to_sock(rsp_data, self._local_sock, addr) + if post_pack_id <= DOUBLE_SEND_BEG_IDS: + rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data) + self._write_to_sock(rsp_data, self._local_sock, addr) + + def handle_client(self, addr, cmd, request_id, data): + self.add_local_address(addr) + if cmd == CMD_DISCONNECT: + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) + self._write_to_sock(rsp_data, self._local_sock, addr) + self.destroy() + self.destroy_local() + return + if self._stage == STAGE_INIT: + if cmd == CMD_CONNECT: + self._request_id = request_id + self._stage = STAGE_RSP_ID + return + if self._request_id != request_id: + return + + if self._stage == STAGE_RSP_ID: + if cmd == CMD_CONNECT: + for i in range(2): + rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, RSP_STATE_CONNECTED) + self._write_to_sock(rsp_data, self._local_sock, addr) + elif cmd == CMD_CONNECT_REMOTE: + local_id = data[0:4] + if self._local_id == local_id: + data = data[4:] + header_result = parse_header(data) + if header_result is None: + return + connecttype, remote_addr, remote_port, header_length = header_result + self._remote_address = (common.to_str(remote_addr), remote_port) + self._stage = STAGE_DNS + self._dns_resolver.resolve(remote_addr, + self._handle_dns_resolved) + logging.info('TCP connect %s:%d from %s:%d' % (remote_addr, remote_port, addr[0], addr[1])) + else: + # ileagal request + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) + self._write_to_sock(rsp_data, self._local_sock, addr) + elif self._stage == STAGE_CONNECTING: + if cmd == CMD_CONNECT_REMOTE: + local_id = data[0:4] + if self._local_id == local_id: + for i in range(2): + rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, RSP_STATE_CONNECTEDREMOTE) + self._write_to_sock(rsp_data, self._local_sock, addr) + else: + # ileagal request + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) + self._write_to_sock(rsp_data, self._local_sock, addr) + elif self._stage == STAGE_STREAM: + if len(data) < 4: + # ileagal request + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) + self._write_to_sock(rsp_data, self._local_sock, addr) + return + local_id = data[0:4] + if self._local_id != local_id: + # ileagal request + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) + self._write_to_sock(rsp_data, self._local_sock, addr) + return + else: + data = data[4:] + if cmd == CMD_CONNECT_REMOTE: + rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, RSP_STATE_CONNECTEDREMOTE) + self._write_to_sock(rsp_data, self._local_sock, addr) + elif cmd == CMD_POST: + recv_id = struct.unpack(">I", data[0:4])[0] + pack_id = struct.unpack(">I", data[4:8])[0] + self._recvqueue.insert(pack_id, data[8:]) + self._sendingqueue.set_finish(recv_id, []) + elif cmd == CMD_POST_64: + recv_id = struct.unpack(">Q", data[0:8])[0] + pack_id = struct.unpack(">Q", data[8:16])[0] + self._recvqueue.insert(pack_id, data[16:]) + self._sendingqueue.set_finish(recv_id, []) + elif cmd == CMD_DISCONNECT: + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) + self._write_to_sock(rsp_data, self._local_sock, addr) + self.destroy() + self.destroy_local() + return + elif cmd == CMD_SYN_STATUS: + pack_id = struct.unpack(">I", data[0:4])[0] + max_send_id = struct.unpack(">I", data[4:8])[0] + data = data[8:] + self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) + elif cmd == CMD_SYN_STATUS_64: + pack_id = struct.unpack(">Q", data[0:8])[0] + max_send_id = struct.unpack(">Q", data[8:16])[0] + data = data[16:] + self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) + while self._recvqueue.has_data(): + pack_id, post_data = self._recvqueue.get_data() + self._write_to_sock(post_data, self._remote_sock) + elif self._stage == STAGE_DESTROYED: + local_id = data[0:4] + if self._local_id != local_id: + # ileagal request + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) + self._write_to_sock(rsp_data, self._local_sock, addr) + return + else: + data = data[4:] + if cmd == CMD_SYN_STATUS: + pack_id = struct.unpack(">I", data[0:4])[0] + max_send_id = struct.unpack(">I", data[4:8])[0] + data = data[8:] + self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) + elif cmd == CMD_SYN_STATUS_64: + pack_id = struct.unpack(">Q", data[0:8])[0] + max_send_id = struct.unpack(">Q", data[8:16])[0] + data = data[16:] + self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data) + + def handle_event(self, sock, event): + # handle all events in this handler and dispatch them to methods + if self._stage == STAGE_DESTROYED: + logging.debug('ignore handle_event: destroyed') + return + # order is important + if sock == self._remote_sock: + if event & eventloop.POLL_ERR: + self._on_remote_error() + if self._stage == STAGE_DESTROYED: + return + if event & (eventloop.POLL_IN | eventloop.POLL_HUP): + self._on_remote_read() + if self._stage == STAGE_DESTROYED: + return + if event & eventloop.POLL_OUT: + self._on_remote_write() + elif sock == self._local_sock: + if event & eventloop.POLL_ERR: + self._on_local_error() + if self._stage == STAGE_DESTROYED: + return + if event & (eventloop.POLL_IN | eventloop.POLL_HUP): + self._on_local_read() + if self._stage == STAGE_DESTROYED: + return + if event & eventloop.POLL_OUT: + self._on_local_write() + else: + logging.warn('unknown socket') + + def _log_error(self, e): + logging.error('%s when handling connection from %s' % + (e, self._client_address.keys())) + + def destroy(self): + # destroy the handler and release any resources + # promises: + # 1. destroy won't make another destroy() call inside + # 2. destroy releases resources so it prevents future call to destroy + # 3. destroy won't raise any exceptions + # if any of the promises are broken, it indicates a bug has been + # introduced! mostly likely memory leaks, etc + #logging.info('tcp destroy called') + if self._stage == STAGE_DESTROYED: + # this couldn't happen + logging.debug('already destroyed') + return + self._stage = STAGE_DESTROYED + if self._remote_address: + logging.debug('destroy: %s:%d' % + self._remote_address) + else: + logging.debug('destroy') + if self._remote_sock: + logging.debug('destroying remote') + self._loop.remove(self._remote_sock) + try: + del self._fd_to_handlers[self._remote_sock.fileno()] + except Exception as e: + pass + self._remote_sock.close() + self._remote_sock = None + if self._sendingqueue.empty(): + self.destroy_local() + self._dns_resolver.remove_callback(self._handle_dns_resolved) + + def destroy_local(self): + if self._local_sock: + logging.debug('disconnect local') + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY) + addr = None + addr = self.get_local_address() + self._write_to_sock(rsp_data, self._local_sock, addr) + self._local_sock = None + try: + del self._reqid_to_handlers[self._request_id] + except Exception as e: + pass + + self._server.remove_handler(self) + +def client_key(source_addr, server_af): + # notice this is server af, not dest af + return '%s:%s:%d' % (source_addr[0], source_addr[1], server_af) class UDPRelay(object): - def __init__(self, config, dns_resolver, is_local): + def __init__(self, config, dns_resolver, is_local, stat_callback=None): self._config = config if is_local: self._listen_addr = config['local_address'] @@ -94,7 +876,7 @@ def __init__(self, config, dns_resolver, is_local): self._remote_addr = None self._remote_port = None self._dns_resolver = dns_resolver - self._password = config['password'] + self._password = common.to_bytes(config['password']) self._method = config['method'] self._timeout = config['timeout'] self._is_local = is_local @@ -102,10 +884,22 @@ def __init__(self, config, dns_resolver, is_local): close_callback=self._close_client) self._client_fd_to_server_addr = \ lru_cache.LRUCache(timeout=config['timeout']) + self._dns_cache = lru_cache.LRUCache(timeout=300) self._eventloop = None self._closed = False - self._last_time = time.time() + self.server_transfer_ul = 0 + self.server_transfer_dl = 0 + self._sockets = set() + self._fd_to_handlers = {} + self._reqid_to_hd = {} + self._data_to_write_to_server_socket = [] + + self._timeouts = [] # a list for all the handlers + # we trim the timeouts once a while + self._timeout_offset = 0 # last checked position for timeout + self._handler_to_timeouts = {} # key: handler value: index in timeouts + if 'forbidden_ip' in config: self._forbidden_iplist = config['forbidden_ip'] else: @@ -120,7 +914,10 @@ def __init__(self, config, dns_resolver, is_local): server_socket = socket.socket(af, socktype, proto) server_socket.bind((self._listen_addr, self._listen_port)) server_socket.setblocking(False) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32) self._server_socket = server_socket + self._stat_callback = stat_callback def _get_a_server(self): server = self._config['server'] @@ -141,11 +938,53 @@ def _close_client(self, client): # just an address pass + def _pre_parse_udp_header(self, data): + if data is None: + return + datatype = common.ord(data[0]) + if datatype == 0x8: + if len(data) >= 8: + crc = binascii.crc32(data) & 0xffffffff + if crc != 0xffffffff: + logging.warn('uncorrect CRC32, maybe wrong password or ' + 'encryption method') + return None + cmd = common.ord(data[1]) + request_id = struct.unpack('>H', data[2:4])[0] + data = data[4:-4] + return (cmd, request_id, data) + elif len(data) >= 6 and common.ord(data[1]) == 0x0: + crc = binascii.crc32(data) & 0xffffffff + if crc != 0xffffffff: + logging.warn('uncorrect CRC32, maybe wrong password or ' + 'encryption method') + return None + cmd = common.ord(data[1]) + data = data[2:-4] + return (cmd, 0, data) + else: + logging.warn('header too short, maybe wrong password or ' + 'encryption method') + return None + return data + + def _pack_rsp_data(self, cmd, request_id, data): + _rand_data = b"123456789abcdefghijklmnopqrstuvwxyz" * 2 + reqid_str = struct.pack(">H", request_id) + return b''.join([CMD_VER_STR, common.chr(cmd), reqid_str, data, _rand_data[:random.randint(0, len(_rand_data))], reqid_str]) + + def _handel_protocol_error(self, client_address, ogn_data): + #raise Exception('can not parse header') + logging.warn("Protocol ERROR, UDP ogn data %s from %s:%d" % (binascii.hexlify(ogn_data), client_address[0], client_address[1])) + def _handle_server(self): server = self._server_socket data, r_addr = server.recvfrom(BUF_SIZE) + ogn_data = data if not data: logging.debug('UDP handle_server: data is empty') + if self._stat_callback: + self._stat_callback(self._listen_port, len(data)) if self._is_local: frag = common.ord(data[2]) if frag != 0: @@ -159,39 +998,129 @@ def _handle_server(self): if not data: logging.debug('UDP handle_server: data is empty after decrypt') return - header_result = parse_header(data) + + #logging.info("UDP data %s" % (binascii.hexlify(data),)) + if not self._is_local: + data = pre_parse_header(data) + + data = self._pre_parse_udp_header(data) + if data is None: + return + + if type(data) is tuple: + #(cmd, request_id, data) + #logging.info("UDP data %d %d %s" % (data[0], data[1], binascii.hexlify(data[2]))) + try: + if data[0] == 0: + if len(data[2]) >= 4: + for i in range(64): + req_id = random.randint(1, 65535) + if req_id not in self._reqid_to_hd: + break + if req_id in self._reqid_to_hd: + for i in range(64): + req_id = random.randint(1, 65535) + if type(self._reqid_to_hd[req_id]) is tuple: + break + # return req id + self._reqid_to_hd[req_id] = (data[2][0:4], None) + rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, req_id, RSP_STATE_CONNECTED) + data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) + self.write_to_server_socket(data_to_send, r_addr) + elif data[0] == CMD_CONNECT_REMOTE: + if len(data[2]) > 4 and data[1] in self._reqid_to_hd: + # create + if type(self._reqid_to_hd[data[1]]) is tuple: + if data[2][0:4] == self._reqid_to_hd[data[1]][0]: + handle = TCPRelayHandler(self, self._reqid_to_hd, self._fd_to_handlers, + self._eventloop, self._server_socket, + self._reqid_to_hd[data[1]][0], self._reqid_to_hd[data[1]][1], + self._config, self._dns_resolver, self._is_local) + self._reqid_to_hd[data[1]] = handle + handle.handle_client(r_addr, CMD_CONNECT, data[1], data[2]) + handle.handle_client(r_addr, *data) + self.update_activity(handle) + else: + # disconnect + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], RSP_STATE_EMPTY) + data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) + self.write_to_server_socket(data_to_send, r_addr) + else: + self.update_activity(self._reqid_to_hd[data[1]]) + self._reqid_to_hd[data[1]].handle_client(r_addr, *data) + else: + # disconnect + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], RSP_STATE_EMPTY) + data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) + self.write_to_server_socket(data_to_send, r_addr) + elif data[0] > CMD_CONNECT_REMOTE and data[0] <= CMD_DISCONNECT: + if data[1] in self._reqid_to_hd: + if type(self._reqid_to_hd[data[1]]) is tuple: + pass + else: + self.update_activity(self._reqid_to_hd[data[1]]) + self._reqid_to_hd[data[1]].handle_client(r_addr, *data) + else: + # disconnect + rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], RSP_STATE_EMPTY) + data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data) + self.write_to_server_socket(data_to_send, r_addr) + return + except Exception as e: + trace = traceback.format_exc() + logging.error(trace) + return + + try: + header_result = parse_header(data) + except: + self._handel_protocol_error(r_addr, ogn_data) + return + if header_result is None: + self._handel_protocol_error(r_addr, ogn_data) return - addrtype, dest_addr, dest_port, header_length = header_result + connecttype, dest_addr, dest_port, header_length = header_result if self._is_local: server_addr, server_port = self._get_a_server() else: server_addr, server_port = dest_addr, dest_port - key = client_key(r_addr[0], r_addr[1], dest_addr, dest_port) - client = self._cache.get(key, None) - if not client: - # TODO async getaddrinfo + addrs = self._dns_cache.get(server_addr, None) + if addrs is None: addrs = socket.getaddrinfo(server_addr, server_port, 0, socket.SOCK_DGRAM, socket.SOL_UDP) - if addrs: - af, socktype, proto, canonname, sa = addrs[0] - if self._forbidden_iplist: - if common.to_str(sa[0]) in self._forbidden_iplist: - logging.debug('IP %s is in forbidden list, drop' % - common.to_str(sa[0])) - # drop - return - client = socket.socket(af, socktype, proto) - client.setblocking(False) - self._cache[key] = client - self._client_fd_to_server_addr[client.fileno()] = r_addr - else: + if not addrs: # drop return + else: + self._dns_cache[server_addr] = addrs + + af, socktype, proto, canonname, sa = addrs[0] + key = client_key(r_addr, af) + client = self._cache.get(key, None) + if not client: + # TODO async getaddrinfo + if self._forbidden_iplist: + if common.to_str(sa[0]) in self._forbidden_iplist: + logging.debug('IP %s is in forbidden list, drop' % + common.to_str(sa[0])) + # drop + return + client = socket.socket(af, socktype, proto) + client.setblocking(False) + self._cache[key] = client + self._client_fd_to_server_addr[client.fileno()] = r_addr + self._sockets.add(client.fileno()) - self._eventloop.add(client, eventloop.POLL_IN) + self._eventloop.add(client, eventloop.POLL_IN, self) + + logging.debug('UDP port %5d sockets %d' % (self._listen_port, len(self._sockets))) + + logging.info('UDP data to %s:%d from %s:%d' % + (common.to_str(server_addr), server_port, + r_addr[0], r_addr[1])) if self._is_local: data = encrypt.encrypt_all(self._password, self._method, 1, data) @@ -202,6 +1131,7 @@ def _handle_server(self): if not data: return try: + #logging.info('UDP handle_server sendto %s:%d %d bytes' % (common.to_str(server_addr), server_port, len(data))) client.sendto(data, (server_addr, server_port)) except IOError as e: err = eventloop.errno_from_exception(e) @@ -215,6 +1145,8 @@ def _handle_client(self, sock): if not data: logging.debug('UDP handle_client: data is empty') return + if self._stat_callback: + self._stat_callback(self._listen_port, len(data)) if not self._is_local: addrlen = len(r_addr[0]) if addrlen > 255: @@ -233,50 +1165,168 @@ def _handle_client(self, sock): header_result = parse_header(data) if header_result is None: return - # addrtype, dest_addr, dest_port, header_length = header_result + #connecttype, dest_addr, dest_port, header_length = header_result + #logging.debug('UDP handle_client %s:%d to %s:%d' % (common.to_str(r_addr[0]), r_addr[1], dest_addr, dest_port)) + response = b'\x00\x00\x00' + data client_addr = self._client_fd_to_server_addr.get(sock.fileno()) if client_addr: - self._server_socket.sendto(response, client_addr) + self.write_to_server_socket(response, client_addr) else: # this packet is from somewhere else we know # simply drop that packet pass + def write_to_server_socket(self, data, addr): + #self._server_socket.sendto(data, addr) + #''' + uncomplete = False + retry = 0 + try: + #""" + #if self._data_to_write_to_server_socket: + # self._data_to_write_to_server_socket.append([(data, addr), 0]) + #else: + self._server_socket.sendto(data, addr) + data = None + while self._data_to_write_to_server_socket: + data_buf = self._data_to_write_to_server_socket[0] + retry = data_buf[1] + 1 + del self._data_to_write_to_server_socket[0] + data, addr = data_buf[0] + self._server_socket.sendto(data, addr) + #""" + except (OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + uncomplete = True + if error_no in (errno.EWOULDBLOCK,): + pass + else: + shell.print_exception(e) + return False + #if uncomplete and data is not None and retry < 3: + # self._data_to_write_to_server_socket.append([(data, addr), retry]) + #''' + def add_to_loop(self, loop): if self._eventloop: raise Exception('already add to loop') if self._closed: raise Exception('already closed') self._eventloop = loop - loop.add_handler(self._handle_events) server_socket = self._server_socket self._eventloop.add(server_socket, - eventloop.POLL_IN | eventloop.POLL_ERR) + eventloop.POLL_IN | eventloop.POLL_ERR, self) + loop.add_periodic(self.handle_periodic) + + def remove_handler(self, handler): + index = self._handler_to_timeouts.get(hash(handler), -1) + if index >= 0: + # delete is O(n), so we just set it to None + self._timeouts[index] = None + del self._handler_to_timeouts[hash(handler)] + + def update_activity(self, handler): + # set handler to active + now = int(time.time()) + if now - handler.last_activity < eventloop.TIMEOUT_PRECISION: + # thus we can lower timeout modification frequency + return + handler.last_activity = now + index = self._handler_to_timeouts.get(hash(handler), -1) + if index >= 0: + # delete is O(n), so we just set it to None + self._timeouts[index] = None + length = len(self._timeouts) + self._timeouts.append(handler) + self._handler_to_timeouts[hash(handler)] = length + + def _sweep_timeout(self): + # tornado's timeout memory management is more flexible than we need + # we just need a sorted last_activity queue and it's faster than heapq + # in fact we can do O(1) insertion/remove so we invent our own + if self._timeouts: + logging.log(shell.VERBOSE_LEVEL, 'sweeping timeouts') + now = time.time() + length = len(self._timeouts) + pos = self._timeout_offset + while pos < length: + handler = self._timeouts[pos] + if handler: + if now - handler.last_activity < self._timeout: + break + else: + if handler.remote_address: + logging.warn('timed out: %s:%d' % + handler.remote_address) + else: + logging.warn('timed out') + handler.destroy() + handler.destroy_local() + self._timeouts[pos] = None # free memory + pos += 1 + else: + pos += 1 + if pos > TIMEOUTS_CLEAN_SIZE and pos > length >> 1: + # clean up the timeout queue when it gets larger than half + # of the queue + self._timeouts = self._timeouts[pos:] + for key in self._handler_to_timeouts: + self._handler_to_timeouts[key] -= pos + pos = 0 + self._timeout_offset = pos - def _handle_events(self, events): - for sock, fd, event in events: - if sock == self._server_socket: - if event & eventloop.POLL_ERR: - logging.error('UDP server_socket err') + def handle_event(self, sock, fd, event): + if sock == self._server_socket: + if event & eventloop.POLL_ERR: + logging.error('UDP server_socket err') + try: self._handle_server() - elif sock and (fd in self._sockets): - if event & eventloop.POLL_ERR: - logging.error('UDP client_socket err') + except Exception as e: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + elif sock and (fd in self._sockets): + if event & eventloop.POLL_ERR: + logging.error('UDP client_socket err') + try: self._handle_client(sock) - now = time.time() - if now - self._last_time > 3: - self._cache.sweep() - self._client_fd_to_server_addr.sweep() - self._last_time = now + except Exception as e: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + else: + if sock: + handler = self._fd_to_handlers.get(fd, None) + if handler: + handler.handle_event(sock, event) + else: + logging.warn('poll removed fd') + + def handle_periodic(self): if self._closed: - self._server_socket.close() - for sock in self._sockets: - sock.close() - self._eventloop.remove_handler(self._handle_events) + if self._server_socket: + self._server_socket.close() + self._server_socket = None + for sock in self._sockets: + sock.close() + logging.info('closed UDP port %d', self._listen_port) + before_sweep_size = len(self._sockets) + self._cache.sweep() + self._dns_cache.sweep() + if before_sweep_size != len(self._sockets): + logging.debug('UDP port %5d sockets %d' % (self._listen_port, len(self._sockets))) + self._client_fd_to_server_addr.sweep() + self._sweep_timeout() def close(self, next_tick=False): + logging.debug('UDP close') self._closed = True if not next_tick: + if self._eventloop: + self._eventloop.remove_periodic(self.handle_periodic) + self._eventloop.remove(self._server_socket) self._server_socket.close() + for client in list(self._cache.values()): + client.close() diff --git a/tests/jenkins.sh b/tests/jenkins.sh index 71d5b1ca..ea5c1630 100755 --- a/tests/jenkins.sh +++ b/tests/jenkins.sh @@ -69,7 +69,7 @@ if [ -f /proc/sys/net/ipv4/tcp_fastopen ] ; then fi run_test tests/test_large_file.sh - +run_test tests/test_udp_src.sh run_test tests/test_command.sh coverage combine && coverage report --include=shadowsocks/* diff --git a/tests/test_udp_src.py b/tests/test_udp_src.py new file mode 100644 index 00000000..e8fa5057 --- /dev/null +++ b/tests/test_udp_src.py @@ -0,0 +1,83 @@ +#!/usr/bin/python + +import socket +import socks + + +SERVER_IP = '127.0.0.1' +SERVER_PORT = 1081 + + +if __name__ == '__main__': + # Test 1: same source port IPv4 + sock_out = socks.socksocket(socket.AF_INET, socket.SOCK_DGRAM, + socket.SOL_UDP) + sock_out.set_proxy(socks.SOCKS5, SERVER_IP, SERVER_PORT) + sock_out.bind(('127.0.0.1', 9000)) + + sock_in1 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, + socket.SOL_UDP) + sock_in2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, + socket.SOL_UDP) + + sock_in1.bind(('127.0.0.1', 9001)) + sock_in2.bind(('127.0.0.1', 9002)) + + sock_out.sendto(b'data', ('127.0.0.1', 9001)) + result1 = sock_in1.recvfrom(8) + + sock_out.sendto(b'data', ('127.0.0.1', 9002)) + result2 = sock_in2.recvfrom(8) + + sock_out.close() + sock_in1.close() + sock_in2.close() + + # make sure they're from the same source port + assert result1 == result2 + + # Test 2: same source port IPv6 + # try again from the same port but IPv6 + sock_out = socks.socksocket(socket.AF_INET, socket.SOCK_DGRAM, + socket.SOL_UDP) + sock_out.set_proxy(socks.SOCKS5, SERVER_IP, SERVER_PORT) + sock_out.bind(('127.0.0.1', 9000)) + + sock_in1 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM, + socket.SOL_UDP) + sock_in2 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM, + socket.SOL_UDP) + + sock_in1.bind(('::1', 9001)) + sock_in2.bind(('::1', 9002)) + + sock_out.sendto(b'data', ('::1', 9001)) + result1 = sock_in1.recvfrom(8) + + sock_out.sendto(b'data', ('::1', 9002)) + result2 = sock_in2.recvfrom(8) + + sock_out.close() + sock_in1.close() + sock_in2.close() + + # make sure they're from the same source port + assert result1 == result2 + + # Test 3: different source ports IPv6 + sock_out = socks.socksocket(socket.AF_INET, socket.SOCK_DGRAM, + socket.SOL_UDP) + sock_out.set_proxy(socks.SOCKS5, SERVER_IP, SERVER_PORT) + sock_out.bind(('127.0.0.1', 9003)) + + sock_in1 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM, + socket.SOL_UDP) + sock_in1.bind(('::1', 9001)) + sock_out.sendto(b'data', ('::1', 9001)) + result3 = sock_in1.recvfrom(8) + + # make sure they're from different source ports + assert result1 != result3 + + sock_out.close() + sock_in1.close() diff --git a/tests/test_udp_src.sh b/tests/test_udp_src.sh new file mode 100755 index 00000000..d356581c --- /dev/null +++ b/tests/test_udp_src.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +PYTHON="coverage run -p -a" + +mkdir -p tmp + +$PYTHON shadowsocks/local.py -c tests/aes.json -v & +LOCAL=$! + +$PYTHON shadowsocks/server.py -c tests/aes.json --forbidden-ip "" -v & +SERVER=$! + +sleep 3 + +python tests/test_udp_src.py +r=$? + +kill -s SIGINT $LOCAL +kill -s SIGINT $SERVER + +sleep 2 + +exit $r diff --git a/utils/fail2ban/shadowsocks.conf b/utils/fail2ban/shadowsocks.conf new file mode 100644 index 00000000..9b1c7ec7 --- /dev/null +++ b/utils/fail2ban/shadowsocks.conf @@ -0,0 +1,5 @@ +[Definition] + +_daemon = shadowsocks + +failregex = ^\s+ERROR\s+can not parse header when handling connection from :\d+$