Home | History | Annotate | Download | only in tests
      1 #!/usr/bin/env python
      2 # -*- coding: utf-8 -*-
      3 #
      4 #  Project                     ___| | | |  _ \| |
      5 #                             / __| | | | |_) | |
      6 #                            | (__| |_| |  _ <| |___
      7 #                             \___|\___/|_| \_\_____|
      8 #
      9 # Copyright (C) 2017, Daniel Stenberg, <daniel (at] haxx.se>, et al.
     10 #
     11 # This software is licensed as described in the file COPYING, which
     12 # you should have received as part of this distribution. The terms
     13 # are also available at https://curl.haxx.se/docs/copyright.html.
     14 #
     15 # You may opt to use, copy, modify, merge, publish, distribute and/or sell
     16 # copies of the Software, and permit persons to whom the Software is
     17 # furnished to do so, under the terms of the COPYING file.
     18 #
     19 # This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
     20 # KIND, either express or implied.
     21 #
     22 """Server for testing SMB"""
     23 
     24 from __future__ import (absolute_import, division, print_function)
     25 # unicode_literals)
     26 import argparse
     27 import ConfigParser
     28 import os
     29 import sys
     30 import logging
     31 import tempfile
     32 
     33 # Import our curl test data helper
     34 import curl_test_data
     35 
     36 # This saves us having to set up the PYTHONPATH explicitly
     37 deps_dir = os.path.join(os.path.dirname(__file__), "python_dependencies")
     38 sys.path.append(deps_dir)
     39 from impacket import smbserver as imp_smbserver
     40 from impacket import smb as imp_smb
     41 from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_SUCCESS,
     42                                 STATUS_NO_SUCH_FILE)
     43 
log = logging.getLogger(__name__)
# Placeholder share paths. They are not real directories: create_and_x()
# recognises them and generates the requested file on the fly.
SERVER_MAGIC = "SERVER_MAGIC"
TESTS_MAGIC = "TESTS_MAGIC"
# File name on the SERVER share that is served to check the server is alive,
# and the response template for it (formatted with this process' PID).
VERIFIED_REQ = "verifiedserver"
VERIFIED_RSP = b"WE ROOLZ: {pid}\n"
     49 
     50 
     51 def smbserver(options):
     52     """Start up a TCP SMB server that serves forever
     53 
     54     """
     55     if options.pidfile:
     56         pid = os.getpid()
     57         with open(options.pidfile, "w") as f:
     58             f.write("{0}".format(pid))
     59 
     60     # Here we write a mini config for the server
     61     smb_config = ConfigParser.ConfigParser()
     62     smb_config.add_section("global")
     63     smb_config.set("global", "server_name", "SERVICE")
     64     smb_config.set("global", "server_os", "UNIX")
     65     smb_config.set("global", "server_domain", "WORKGROUP")
     66     smb_config.set("global", "log_file", "")
     67     smb_config.set("global", "credentials_file", "")
     68 
     69     # We need a share which allows us to test that the server is running
     70     smb_config.add_section("SERVER")
     71     smb_config.set("SERVER", "comment", "server function")
     72     smb_config.set("SERVER", "read only", "yes")
     73     smb_config.set("SERVER", "share type", "0")
     74     smb_config.set("SERVER", "path", SERVER_MAGIC)
     75 
     76     # Have a share for tests.  These files will be autogenerated from the
     77     # test input.
     78     smb_config.add_section("TESTS")
     79     smb_config.set("TESTS", "comment", "tests")
     80     smb_config.set("TESTS", "read only", "yes")
     81     smb_config.set("TESTS", "share type", "0")
     82     smb_config.set("TESTS", "path", TESTS_MAGIC)
     83 
     84     if not options.srcdir or not os.path.isdir(options.srcdir):
     85         raise ScriptException("--srcdir is mandatory")
     86 
     87     test_data_dir = os.path.join(options.srcdir, "data")
     88 
     89     smb_server = TestSmbServer(("127.0.0.1", options.port),
     90                                config_parser=smb_config,
     91                                test_data_directory=test_data_dir)
     92     log.info("[SMB] setting up SMB server on port %s", options.port)
     93     smb_server.processConfigFile()
     94     smb_server.serve_forever()
     95     return 0
     96 
     97 
     98 class TestSmbServer(imp_smbserver.SMBSERVER):
     99     """
    100     Test server for SMB which subclasses the impacket SMBSERVER and provides
    101     test functionality.
    102     """
    103 
    104     def __init__(self,
    105                  address,
    106                  config_parser=None,
    107                  test_data_directory=None):
    108         imp_smbserver.SMBSERVER.__init__(self,
    109                                          address,
    110                                          config_parser=config_parser)
    111 
    112         # Set up a test data object so we can get test data later.
    113         self.ctd = curl_test_data.TestData(test_data_directory)
    114 
    115         # Override smbComNtCreateAndX so we can pretend to have files which
    116         # don't exist.
    117         self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX,
    118                             self.create_and_x)
    119 
    120     def create_and_x(self, conn_id, smb_server, smb_command, recv_packet):
    121         """
    122         Our version of smbComNtCreateAndX looks for special test files and
    123         fools the rest of the framework into opening them as if they were
    124         normal files.
    125         """
    126         conn_data = smb_server.getConnectionData(conn_id)
    127 
    128         # Wrap processing in a try block which allows us to throw SmbException
    129         # to control the flow.
    130         try:
    131             ncax_parms = imp_smb.SMBNtCreateAndX_Parameters(
    132                 smb_command["Parameters"])
    133 
    134             path = self.get_share_path(conn_data,
    135                                        ncax_parms["RootFid"],
    136                                        recv_packet["Tid"])
    137             log.info("[SMB] Requested share path: %s", path)
    138 
    139             disposition = ncax_parms["Disposition"]
    140             log.debug("[SMB] Requested disposition: %s", disposition)
    141 
    142             # Currently we only support reading files.
    143             if disposition != imp_smb.FILE_OPEN:
    144                 raise SmbException(STATUS_ACCESS_DENIED,
    145                                    "Only support reading files")
    146 
    147             # Check to see if the path we were given is actually a
    148             # magic path which needs generating on the fly.
    149             if path not in [SERVER_MAGIC, TESTS_MAGIC]:
    150                 # Pass the command onto the original handler.
    151                 return imp_smbserver.SMBCommands.smbComNtCreateAndX(conn_id,
    152                                                                     smb_server,
    153                                                                     smb_command,
    154                                                                     recv_packet)
    155 
    156             flags2 = recv_packet["Flags2"]
    157             ncax_data = imp_smb.SMBNtCreateAndX_Data(flags=flags2,
    158                                                      data=smb_command[
    159                                                          "Data"])
    160             requested_file = imp_smbserver.decodeSMBString(
    161                 flags2,
    162                 ncax_data["FileName"])
    163             log.debug("[SMB] User requested file '%s'", requested_file)
    164 
    165             if path == SERVER_MAGIC:
    166                 fid, full_path = self.get_server_path(requested_file)
    167             else:
    168                 assert (path == TESTS_MAGIC)
    169                 fid, full_path = self.get_test_path(requested_file)
    170 
    171             resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters()
    172             resp_data = ""
    173 
    174             # Simple way to generate a fid
    175             if len(conn_data["OpenedFiles"]) == 0:
    176                 fakefid = 1
    177             else:
    178                 fakefid = conn_data["OpenedFiles"].keys()[-1] + 1
    179             resp_parms["Fid"] = fakefid
    180             resp_parms["CreateAction"] = disposition
    181 
    182             if os.path.isdir(path):
    183                 resp_parms[
    184                     "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY
    185                 resp_parms["IsDirectory"] = 1
    186             else:
    187                 resp_parms["IsDirectory"] = 0
    188                 resp_parms["FileAttributes"] = ncax_parms["FileAttributes"]
    189 
    190             # Get this file's information
    191             resp_info, error_code = imp_smbserver.queryPathInformation(
    192                 "", full_path, level=imp_smb.SMB_QUERY_FILE_ALL_INFO)
    193 
    194             if error_code != STATUS_SUCCESS:
    195                 raise SmbException(error_code, "Failed to query path info")
    196 
    197             resp_parms["CreateTime"] = resp_info["CreationTime"]
    198             resp_parms["LastAccessTime"] = resp_info[
    199                 "LastAccessTime"]
    200             resp_parms["LastWriteTime"] = resp_info["LastWriteTime"]
    201             resp_parms["LastChangeTime"] = resp_info[
    202                 "LastChangeTime"]
    203             resp_parms["FileAttributes"] = resp_info[
    204                 "ExtFileAttributes"]
    205             resp_parms["AllocationSize"] = resp_info[
    206                 "AllocationSize"]
    207             resp_parms["EndOfFile"] = resp_info["EndOfFile"]
    208 
    209             # Let's store the fid for the connection
    210             # smbServer.log("Create file %s, mode:0x%x" % (pathName, mode))
    211             conn_data["OpenedFiles"][fakefid] = {}
    212             conn_data["OpenedFiles"][fakefid]["FileHandle"] = fid
    213             conn_data["OpenedFiles"][fakefid]["FileName"] = path
    214             conn_data["OpenedFiles"][fakefid]["DeleteOnClose"] = False
    215 
    216         except SmbException as s:
    217             log.debug("[SMB] SmbException hit: %s", s)
    218             error_code = s.error_code
    219             resp_parms = ""
    220             resp_data = ""
    221 
    222         resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX)
    223         resp_cmd["Parameters"] = resp_parms
    224         resp_cmd["Data"] = resp_data
    225         smb_server.setConnectionData(conn_id, conn_data)
    226 
    227         return [resp_cmd], None, error_code
    228 
    229     def get_share_path(self, conn_data, root_fid, tid):
    230         conn_shares = conn_data["ConnectedShares"]
    231 
    232         if tid in conn_shares:
    233             if root_fid > 0:
    234                 # If we have a rootFid, the path is relative to that fid
    235                 path = conn_data["OpenedFiles"][root_fid]["FileName"]
    236                 log.debug("RootFid present %s!" % path)
    237             else:
    238                 if "path" in conn_shares[tid]:
    239                     path = conn_shares[tid]["path"]
    240                 else:
    241                     raise SmbException(STATUS_ACCESS_DENIED,
    242                                        "Connection share had no path")
    243         else:
    244             raise SmbException(imp_smbserver.STATUS_SMB_BAD_TID,
    245                                "TID was invalid")
    246 
    247         return path
    248 
    249     def get_server_path(self, requested_filename):
    250         log.debug("[SMB] Get server path '%s'", requested_filename)
    251 
    252         if requested_filename not in [VERIFIED_REQ]:
    253             raise SmbException(STATUS_NO_SUCH_FILE, "Couldn't find the file")
    254 
    255         fid, filename = tempfile.mkstemp()
    256         log.debug("[SMB] Created %s (%d) for storing '%s'",
    257                   filename, fid, requested_filename)
    258 
    259         contents = ""
    260 
    261         if requested_filename == VERIFIED_REQ:
    262             log.debug("[SMB] Verifying server is alive")
    263             contents = VERIFIED_RSP.format(pid=os.getpid())
    264 
    265         self.write_to_fid(fid, contents)
    266         return fid, filename
    267 
    268     def write_to_fid(self, fid, contents):
    269         # Write the contents to file descriptor
    270         os.write(fid, contents)
    271         os.fsync(fid)
    272 
    273         # Rewind the file to the beginning so a read gets us the contents
    274         os.lseek(fid, 0, os.SEEK_SET)
    275 
    276     def get_test_path(self, requested_filename):
    277         log.info("[SMB] Get reply data from 'test%s'", requested_filename)
    278 
    279         fid, filename = tempfile.mkstemp()
    280         log.debug("[SMB] Created %s (%d) for storing test '%s'",
    281                   filename, fid, requested_filename)
    282 
    283         try:
    284             contents = self.ctd.get_test_data(requested_filename)
    285             self.write_to_fid(fid, contents)
    286             return fid, filename
    287 
    288         except Exception:
    289             log.exception("Failed to make test file")
    290             raise SmbException(STATUS_NO_SUCH_FILE, "Failed to make test file")
    291 
    292 
    293 class SmbException(Exception):
    294     def __init__(self, error_code, error_message):
    295         super(SmbException, self).__init__(error_message)
    296         self.error_code = error_code
    297 
    298 
    299 class ScriptRC(object):
    300     """Enum for script return codes"""
    301     SUCCESS = 0
    302     FAILURE = 1
    303     EXCEPTION = 2
    304 
    305 
    306 class ScriptException(Exception):
    307     pass
    308 
    309 
    310 def get_options():
    311     parser = argparse.ArgumentParser()
    312 
    313     parser.add_argument("--port", action="store", default=9017,
    314                       type=int, help="port to listen on")
    315     parser.add_argument("--verbose", action="store", type=int, default=0,
    316                         help="verbose output")
    317     parser.add_argument("--pidfile", action="store",
    318                         help="file name for the PID")
    319     parser.add_argument("--logfile", action="store",
    320                         help="file name for the log")
    321     parser.add_argument("--srcdir", action="store", help="test directory")
    322     parser.add_argument("--id", action="store", help="server ID")
    323     parser.add_argument("--ipv4", action="store_true", default=0,
    324                         help="IPv4 flag")
    325 
    326     return parser.parse_args()
    327 
    328 
    329 def setup_logging(options):
    330     """
    331     Set up logging from the command line options
    332     """
    333     root_logger = logging.getLogger()
    334     add_stdout = False
    335 
    336     formatter = logging.Formatter("%(asctime)s %(levelname)-5.5s %(message)s")
    337 
    338     # Write out to a logfile
    339     if options.logfile:
    340         handler = logging.FileHandler(options.logfile, mode="w")
    341         handler.setFormatter(formatter)
    342         handler.setLevel(logging.DEBUG)
    343         root_logger.addHandler(handler)
    344     else:
    345         # The logfile wasn't specified. Add a stdout logger.
    346         add_stdout = True
    347 
    348     if options.verbose:
    349         # Add a stdout logger as well in verbose mode
    350         root_logger.setLevel(logging.DEBUG)
    351         add_stdout = True
    352     else:
    353         root_logger.setLevel(logging.INFO)
    354 
    355     if add_stdout:
    356         stdout_handler = logging.StreamHandler(sys.stdout)
    357         stdout_handler.setFormatter(formatter)
    358         stdout_handler.setLevel(logging.DEBUG)
    359         root_logger.addHandler(stdout_handler)
    360 
    361 
    362 if __name__ == '__main__':
    363     # Get the options from the user.
    364     options = get_options()
    365 
    366     # Setup logging using the user options
    367     setup_logging(options)
    368 
    369     # Run main script.
    370     try:
    371         rc = smbserver(options)
    372     except Exception as e:
    373         log.exception(e)
    374         rc = ScriptRC.EXCEPTION
    375 
    376     log.info("[SMB] Returning %d", rc)
    377     sys.exit(rc)
    378