エンタープライズギークス (Enterprise Geeks)

企業システムの企画・開発に携わる技術者集団のブログです。開発言語やフレームワークなどアプリケーション開発に関する各種情報を発信しています。ウルシステムズのエンジニア有志が運営しています。

ServletFilterによるStruts1脆弱性対策コード

以前の記事で「ServletFilterではマルチパートリクエストに対応できない」と書いたが、Strutsの内部で処理しているマルチパートリクエストのパース処理を実装すれば、ServletFilterだけでも脆弱性問題への対処が可能である。

この方法は、ServletFilterのコード量が多くなるデメリットがあるものの、Struts本体やbeanutilsに手を加えずに対応できる運用上のメリットがある。Strutsのバージョンに依存せず(&同じ脆弱性を持つ他のフレームワークに対しても)同一の方法で対処できるメリットもある。

ここでは、あるクライアントに対して我々が適用した、ServletFilterだけでStruts脆弱性問題に対応する方法を紹介する。

除外するリクエストのルール設定

除外するリクエストのルールは、excludeというキーにカンマ区切りで正規表現を羅列して、web.xml内に設定できるようにした。
以下の設定では、今回発見されたセキュリティホールに加えて、MultipartRequestHandlerServletに対するアクセスも禁止している。

<filter>
    <filter-name>StrutsSecurityFilter</filter-name>
    <filter-class>jp.co.StrutsSecurityFilter</filter-class>
    <init-param>
        <param-name>exclude</param-name>
        <param-value>(^|\W)[cC]lass\W,(^|\W)[Mm]ultipartRequestHandler\W,(^|\W)[Ss]ervlet\W</param-value>
    </init-param>
</filter>

ServletFilter

ServletFilterの処理は次の通り(クラス名はStrutsSecurityFilterとした)。

  1. web.xmlで設定したexcludeパターンをパースして、除外パターンを初期化する。
  2. リクエストパラメーターが除外パターンに該当するかどうかを判定する。
    チェックの結果、不正リクエストと判断した場合は、IllegalArgumentExceptionを投げる。
  3. 次に、Jakarta CommonsのFileUploadを用い、MIMEのマルチパートか否かを判定する(マルチパートでなければチェックは終了)。
  4. マルチパートリクエストだった場合には、リクエストをいったん自作のBufferedRequestWrapperにラッピングする。
    このラッパーは、リクエストデータをメモリー中にバッファリングする(バッファリングする理由は、Strutsに制御を渡す際に、InputStreamを読み込み前の状態に戻す必要があるため)。なお、バッファリングには上限を設け、巨大なサブミットデータにより攻撃を受けた際にヒープが食い尽くされる事が無いように制御する。
  5. Jakarta CommonsのFileUploadを用い、multipart/form-dataをパースする。 パースした後、除外パターンに該当するかを判定し、不正リクエストの場合には、IllegalArgumentExceptionを投げる。
  6. Strutsに制御を渡す。
import java.io.IOException;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import org.apache.commons.fileupload.FileItem;
import org.apache.commons.fileupload.FileUpload;
import org.apache.commons.fileupload.FileUploadBase;
import org.apache.commons.fileupload.FileUploadException;
import org.apache.log4j.Logger;

/**
 * Apache Struts の脆弱性(CVE-2014-0094)に対応するためのフィルター
 */
public class StrutsSecurityFilter implements Filter {

    /** Log4j Logger */
    private static Logger LOG = Logger.getLogger(StrutsSecurityFilter.class);

    /**
     * 除外パターンリスト
     */
    private static List<Pattern> excludePatternList = new ArrayList<Pattern>();

    /**
     * <code>EXCLUDE_PARAM_PATH</code>
     */
    private static final String EXCLUDE_PARAM_PATH = "exclude";

    @SuppressWarnings("unused")
    @Override
    public void init(FilterConfig config) throws ServletException {
        String exclude = config.getInitParameter(EXCLUDE_PARAM_PATH);
        if (exclude == null) {
            return;
        }
        String[] excludes = exclude.split(",");
        for (String s : excludes) {
            Pattern pattern = Pattern.compile(s);
            excludePatternList.add(pattern);
        }
        LOG.info("StrutsSecurityFilter initialize. exclude parameters ["
                + excludePatternList + "]");
    }

    @SuppressWarnings("deprecation")
    @Override
    public void doFilter(ServletRequest request, ServletResponse response,
            FilterChain chain) throws IOException, ServletException {

        HttpServletRequest httpRequest = (HttpServletRequest) request;

        validateParamaeterNames(httpRequest);
        if (!FileUploadBase.isMultipartContent(httpRequest)) {
            chain.doFilter(request, response);
        } else {
            BufferedRequestWrapper bufferedRequest = createRequestWrapper(httpRequest);
            List fileItemList = parseRequest(bufferedRequest);
            FileItem fileItem;
            for (Iterator fileItemListIt = fileItemList.iterator(); fileItemListIt
                    .hasNext(); validateParameter(fileItem.getFieldName())) {
                fileItem = (FileItem) fileItemListIt.next();
            }
            chain.doFilter(bufferedRequest, response);
        }
    }

    @Override
    public void destroy() {
        // noop
    }

    /**
     * HttpRequestWrapperを生成する。
     *
     * @param request
     *            http request
     * @return wrapper
     * @throws ServletException
     *             例外発生時
     * @throws IOException
     *             例外発生時
     */
    protected BufferedRequestWrapper createRequestWrapper(
            HttpServletRequest request) throws ServletException, IOException {
        return new BufferedRequestWrapper(request);
    }

    /**
     * multipart/form-data 形式パラメタをパースする。
     *
     * @param bufferedRequest
     *            request
     * @return パラメタリスト
     * @throws IOException
     *             例外発生時
     */
    @SuppressWarnings("deprecation")
    protected List parseRequest(HttpServletRequestWrapper bufferedRequest)
            throws IOException {
        FileUpload upload = new FileUpload(new DefaultFileItemFactory());
        try {
            return upload.parseRequest(bufferedRequest);
        } catch (FileUploadException e) {
            throw new IllegalStateException("request parse error", e);
        }
    }

    /**
     * httpRequest.getParameterNames()の妥当性を確認する。
     *
     * @param httpRequest
     *            http request
     */
    protected void validateParamaeterNames(HttpServletRequest httpRequest) {
        Enumeration<?> params = httpRequest.getParameterNames();
        while (params.hasMoreElements()) {
            String paramName = (String) params.nextElement();
            validateParameter(paramName);
        }
    }

    /**
     * validate parameter
     *
     * @param target
     *            the target
     */
    protected void validateParameter(String target) {
        if (isAttack(target)) {
            String msg = String.format(
                    "Parameter [%s] is on the excludeParams list of patterns!",
                    target);
            LOG.error(msg);
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * 攻撃対象文字列かを判定する。
     *
     * @param target
     *            対象文字列
     * @return 攻撃対象文字列の場合はtrue。
     */
    protected boolean isAttack(String target) {
        for (Pattern pattern : excludePatternList) {
            if (pattern.matcher(target).find()) {
                return true;
            }
        }
        return false;
    }
}

マルチパートリクエストをラッピングする自作のBufferedRequestWrapperのコードは次の通り。サブミットデータのサイズに上限は、POST_MAX_SIZE定数に持たせている。

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;

import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

/**
 * Bufferd request wrapper.
 */
class BufferedRequestWrapper extends HttpServletRequestWrapper {
    /** POSTの最大サイズ */
    private static final int POST_MAX_SIZE = 10485760;

    /**
     * <code>buffer</code>
     */
    private byte[] buffer;

    /**
     * コンストラクタ。
     *
     * @param request
     *            request
     * @throws IOException
     *             例外発生時
     * @throws IllegalArgumentException
     *             POSTサイズの上限に達した時
     */
    public BufferedRequestWrapper(HttpServletRequest request)
            throws IOException {
        super(request);

        InputStream is = request.getInputStream();
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        byte buff[] = new byte[1024];
        int read;
        while (baos.size() <= POST_MAX_SIZE && (read = is.read(buff)) > 0) {
            baos.write(buff, 0, read);
        }
        if (baos.size() >= POST_MAX_SIZE) {
            throw new IllegalArgumentException("POST_MAX_SIZE is exceeded");
        }
        this.buffer = baos.toByteArray();
    }

    @SuppressWarnings("unused")
    @Override
    public ServletInputStream getInputStream() throws IOException {
        return new BufferedServletInputStream(this.buffer);
    }

    /**
     * InputStream.
     */
    private static class BufferedServletInputStream extends ServletInputStream {

        private ByteArrayInputStream inputStream;

        public BufferedServletInputStream(byte[] buffer) {
            this.inputStream = new ByteArrayInputStream(buffer);
        }

        @SuppressWarnings("unused")
        @Override
        public int available() throws IOException {
            return inputStream.available();
        }

        @Override
        public void close() throws IOException {
            this.inputStream.close();
        }

        @Override
        public void mark(int readAheadLimit) {
            this.inputStream.mark(readAheadLimit);
        }

        @Override
        public boolean markSupported() {
            return this.inputStream.markSupported();
        }

        @SuppressWarnings("unused")
        @Override
        public int read() throws IOException {
            return inputStream.read();
        }

        @Override
        public int read(byte b[]) throws IOException {
            return this.inputStream.read(b);
        }

        @SuppressWarnings("unused")
        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            return inputStream.read(b, off, len);
        }

        @SuppressWarnings("unused")
        @Override
        public void reset() throws IOException {
            this.inputStream.reset();
        }

        @SuppressWarnings("unused")
        @Override
        public long skip(long n) throws IOException {
            return this.inputStream.skip(n);
        }
    }
}

上述のStrutsSecurityFilterでは、以下のようにJakarta CommonsのFileUploadを用いてmultipart/form-dataをパースしている。

FileUpload upload = new FileUpload(new DefaultFileItemFactory());

この処理は、デフォルトの挙動ではディスク上にアップロードされたデータを書き出すが、フィルターはアップロードされたコンテンツの内容は必要としない。
このため、DefaultFileItemFactory, BinaryFileItem, TextFileItem, NullOutputStreamといったクラスを使って、アップロードコンテンツを破棄する。

import org.apache.commons.fileupload.FileItem;
import org.apache.commons.fileupload.FileItemFactory;

/**
 * Default FileItemFactory
 */
public class DefaultFileItemFactory implements FileItemFactory {

    /**
     * コンストラクタ
     */
    public DefaultFileItemFactory() {
    }

    @SuppressWarnings("unused")
    @Override
    public FileItem createItem(final String fieldName,
            final String contentType, final boolean isFormField,
            final String fileName) {
        if (isFormField)
            return new TextFileItem(fieldName);
        else
            return new BinaryFileItem(fieldName);
    }
}
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

import org.apache.commons.fileupload.FileItem;
import org.apache.commons.fileupload.FileItemHeaders;

/**
 * Binay fileitem.
 */
class BinaryFileItem implements FileItem {

    private static final long serialVersionUID = 0x545d13c2f7d06fd7L;
    private String fieldName;
    private transient NullOutputStream outputStream;

    public BinaryFileItem(String fieldName) {
        this.fieldName = fieldName;
        outputStream = new NullOutputStream();
    }

    @Deprecated
    @Override
    public void delete() {
        // noop
    }

    @Deprecated
    @Override
    public byte[] get() {
        return null;
    }

    @Deprecated
    @Override
    public String getContentType() {
        return null;
    }

    @Deprecated
    @Override
    public String getFieldName() {
        return fieldName;
    }

    @SuppressWarnings("unused")
    @Deprecated
    @Override
    public InputStream getInputStream() throws IOException {
        return null;
    }

    @Deprecated
    @Override
    public String getName() {
        return null;
    }

    @Deprecated
    @Override
    public OutputStream getOutputStream() {
        return outputStream;
    }

    @Deprecated
    @Override
    public long getSize() {
        return 0L;
    }

    @Deprecated
    @Override
    public String getString() {
        return null;
    }

    @Deprecated
    @Override
    public String getString(@SuppressWarnings("unused") String encoding) {
        return null;
    }

    @Deprecated
    @Override
    public boolean isFormField() {
        return false;
    }

    @Deprecated
    @Override
    public boolean isInMemory() {
        return true;
    }

    @Override
    public void setFieldName(String fieldName) {
        this.fieldName = fieldName;
    }

    @Deprecated
    @Override
    public void setFormField(@SuppressWarnings("unused") boolean flag) {
        // noop
    }

    @Deprecated
    @Override
    public void write(@SuppressWarnings("unused") File file1) {
        // noop
    }

    @Deprecated
    @Override
    public FileItemHeaders getHeaders() {
        return null;
    }

    @Deprecated
    @Override
    public void setHeaders(
            @SuppressWarnings("unused") FileItemHeaders fileitemheaders) {
        // noop
    }

}
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;


import org.apache.commons.fileupload.FileItem;
import org.apache.commons.fileupload.FileItemHeaders;

/**
 * Text fileitem.
 */
class TextFileItem implements FileItem {

    private String fieldName;
    private transient NullOutputStream outputStream;

    public TextFileItem(String fieldName) {
        this.fieldName = fieldName;
        outputStream = new NullOutputStream();
    }

    @Deprecated
    @Override
    public void delete() {
        // noop
    }

    @Deprecated
    @Override
    public byte[] get() {
        return null;
    }

    @Override
    public String getContentType() {
        return null;
    }

    @Override
    public String getFieldName() {
        return fieldName;
    }

    @SuppressWarnings("unused")
    @Deprecated
    @Override
    public InputStream getInputStream() throws IOException {
        return null;
    }

    @Override
    public String getName() {
        return null;
    }

    @Override
    public OutputStream getOutputStream() {
        return outputStream;
    }

    @Deprecated
    @Override
    public long getSize() {
        return 0L;
    }

    @Deprecated
    @Override
    public String getString() {
        return null;
    }

    @SuppressWarnings("unused")
    @Deprecated
    @Override
    public String getString(String encoding)
            throws UnsupportedEncodingException {
        return null;
    }

    @Override
    public boolean isFormField() {
        return true;
    }

    @Deprecated
    @Override
    public boolean isInMemory() {
        return true;
    }

    @Override
    public void setFieldName(String fieldName) {
        this.fieldName = fieldName;
    }

    @Deprecated
    @Override
    public void setFormField(@SuppressWarnings("unused") boolean flag) {
        // noop
    }

    @Deprecated
    @Override
    public void write(@SuppressWarnings("unused") File file1) {
        // noop
    }

    @Deprecated
    @Override
    public FileItemHeaders getHeaders() {
        return null;
    }

    @Deprecated
    @Override
    public void setHeaders(
            @SuppressWarnings("unused") FileItemHeaders fileitemheaders) {
        // noop
    }
}
import java.io.IOException;
import java.io.OutputStream;

/**
 * NullOutputStream.
 */
class NullOutputStream extends OutputStream {

    public NullOutputStream() {
    }

    @SuppressWarnings("unused")
    @Override
    public void write(int i) throws IOException {
        // noop
    }

    @SuppressWarnings("unused")
    @Override
    public void write(byte abyte0[], int i, int j) throws IOException {
        // noop
    }
}

[高橋 友樹]