深入学习Mybatis-手写Mybatis框架


1. dom4j解析XML文件

  • 引入parse-xml-by-dom4j
<!--dom4j的依赖-->
      <dependency>
          <groupId>org.dom4j</groupId>
          <artifactId>dom4j</artifactId>
          <version>2.1.3</version>
      </dependency>
      <!--jaxen的依赖-->
      <dependency>
          <groupId>jaxen</groupId>
          <artifactId>jaxen</artifactId>
          <version>1.2.0</version>
      </dependency>

核心解析过程以代码的形式给出

import org.dom4j.Document;
import org.dom4j.DocumentException;
import org.dom4j.Element;
import org.dom4j.Node;
import org.dom4j.io.SAXReader;
import org.junit.Test;

import java.io.InputStream;
import java.util.List;

/**
 * 功能描述
 *
 * @author: 张庭杰
 * @date: 2022年11月15日 23:21
 */
public class Test01 {
    @Test
    public void testXML(){
        //创建SAXReader对象
        SAXReader reader = new SAXReader();
        //获取输入流
        InputStream inputStream = ClassLoader.getSystemClassLoader().getResourceAsStream("mybatis-config.xml");
        try {
            //读取XML文件,返回document对象,document对象是文档,代表了整个XML文件
            Document document = reader.read(inputStream);
            //这个xpath代表了:从根下开始找configuration标签,然后找configuration标签下的子标签environments
            //xpath是做标签路径匹配的,能让我们快速定位XML
            String xPath = "/configuration/environments";
            //强转后使用更便捷
            Element environments = (Element) document.selectSingleNode(xPath);
            //获取默认环境的id
            String defaultValue = environments.attributeValue("default");
            //然后你要获取这个环境下的配置
            xPath = "/configuration/environments/environment[@id='"+defaultValue+"']";
            //获取孩子节点
            Element environment = (Element) document.selectSingleNode(xPath);
            //1.获取事务管理器,并确定事务管理的类型
            Element transactionManager = environment.element("transactionManager");
            String type = transactionManager.attributeValue("type");
            if(type.equals("JDBC")){
                //获取事务管理器
                System.out.println("JDBC transaction");
            }else{
                System.out.println("MANGED");
            }
            //2.获取数据库连接
            Element dataSource = environment.element("dataSource");
            String dataSourceType = dataSource.attributeValue("type");
            if(dataSourceType.equals("POOLED")){
                System.out.println("数据库连接池");
            }else{
                System.out.println("不是数据库连接池");
            }
            //3.JDBC连接参数
            List<Element> properties = dataSource.elements();
            System.out.println("JDBC连接参数");
            for (Element property : properties) {
                System.out.println(property.attributeValue("name")+","+property.attributeValue("value"));
            }
            //4.获取mapper标签,从任一位置开始,获取所有的某个标签
            xPath = "//mapper";
            List<Node> mappers = document.selectNodes(xPath);
            System.out.println("mappers:");
            mappers.forEach(mapper ->{
                Element mapperElement = (Element) mapper;
                String resources = mapperElement.attributeValue("resource");
                System.out.println(resources);
            });
        } catch (DocumentException e) {
            e.printStackTrace();
        }
    }

    @Test
    public void testMapperXml(){
        SAXReader saxReader = new SAXReader();
        InputStream input = ClassLoader.getSystemClassLoader().getResourceAsStream("mapper/StudentMapper.xml");
        try {
            Document document = saxReader.read(input);
            //获取namespace
            Element mapper = document.getRootElement();
            //获取命名空间
            String xPath = "/mapper";
            Element mapperElement = (Element) document.selectSingleNode(xPath);
            System.out.println("命名空间"+mapperElement.attributeValue("namespace"));
            //获取所有的子节点
            List<Element> elements = mapperElement.elements();
            elements.forEach(element -> {
                //获取sqlId
                String id = element.attributeValue("id");
                System.out.print("sqlId:"+id);
                //获取resultType
                String resultType = element.attributeValue("resultType");
                if (resultType!=null) {
                    System.out.print(",数据返回类型是:"+resultType);
                }
                //获取标签中的文本,并且去除前后空白
                String sql = element.getTextTrim();
                System.out.print(",sql语句是"+sql);
                //转换
                sql = sql.replaceAll("#\\{[0-9A-Za-z_$]*}", "?");//替换为新sql
                System.out.println(",新sql是"+sql);
            });
        } catch (DocumentException e) {
            e.printStackTrace();
        }
    }
}

2. 创建基础类以及基础类编写

  • 编写获取资源文件的工具类
public class Resources {
    private Resources(){}//建议私有化,不需要创建对象就能调用
    //获取资源文件并以流的方式返回
    public static InputStream getResourceAsStream(String resource){
        return ClassLoader.getSystemClassLoader().getResourceAsStream(resource);
    }
}
  • 定义核心对象
public class SqlSessionFactoryBuilder {
    public SqlSessionFactoryBuilder(){}
    public SqlSessionFactory build(InputStream resources){}
}
/**
 * 一个对象对应一个数据库
 * 通过SqlSessionFactory对象可以获取SqlSession对象(开启会话)
 * 一个SqlSessionFactory对象可以开启多个会话。 
 */
public class SqlSessionFactory {}

SqlSessionFactory对象需要什么属性?

  • 1.事务管理器
public interface Transaction {
    /**
     * 提交事务
     */
    void commit();

    /**
     * 回滚事务
     */
    void rollBack();

    /**
     * 关闭事务
     */
    void close();
    void openConnection();
    Connection getConnection();
}
public class JDBCTransaction implements Transaction{
    /**
     * 自动提交标志
     */
    private boolean autoCommit;
    //2.数据库数据源
    private DataSource dataSource;

    private Connection connection;//连接对象

    //创建事务管理器对象
    public JDBCTransaction(boolean autoCommit, DataSource dataSource) {
        this.autoCommit = autoCommit;
        this.dataSource = dataSource;
    }

    @Override
    public void commit() {
        //控制事务的时候实际上还是使用connection对象
        //因此做到这一步还要做一个获取连接对象的操作
        //因此到这里先打住,先去做数据源的实现
        try {
            connection.commit();
        } catch (SQLException e) {
            e.printStackTrace();
        }

    }

    @Override
    public void rollBack() {
        try {
            connection.rollback();
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

    @Override
    public void close() {
        try {
            connection.close();
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

    @Override
    public void openConnection(){
        if (connection == null) {
            try {
                connection = dataSource.getConnection();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }

    @Override
    public Connection getConnection() {
        return connection;
    }


}
  • 2.数据库数据源

做到这一步,我们发现其实factory对象内部其实是不需要数据源对象的,这是因为针对于数据源连接的操作都是针对事务进行操作的,因此我们只需要在事务管理器中导入数据源对象即可。

而数据源的实现是需要遵循一套规范的DataSource,只要符合这套规范就可以提供数据源。

所以到现在,我们要做的就是实现它的子类。

public class UnPooledDataSource implements DataSource {
    //这些值将从配置文件中读取
    private String driver;
    private String url;
    private String username;
    private String password;
    public UnPooledDataSource(String driver, String url, String username, String password) {
        try {
            Class.forName(driver);
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
        this.url = url;
        this.username = username;
        this.password = password;
    }

    @Override
    public Connection getConnection() throws SQLException {
        return DriverManager.getConnection(url, password, username);
    }
  • 3.mapper集合:这个可以如何处理?
    • 通过观察sql语句的定义可以用id=>sql相关信息这样的映射
    • 为什么不直接是sql语句?这是因为还有一些诸如resultType这样的配置信息需要用户进行配置,框架进行解析
    • 因此,还需要将相关信息封装到标签对象中

因此封装了sql标签的对象是这样定义的

public class MapperStatement {
    private String sql;
    private String resultType;
}

3. 解析xml文档并构造对象

这一步其实就是在构造一个SqlSessionFactory对象,这个读取了必要的配置信息,并且根据这些必要信息生成了必要的重量级对象。

首先我们要获取一个事务管理器,而事务管理器是依赖于数据源对象,因此最先应该先构造数据源对象,而mapper文件是独立的,因此你放在那一步开始做都可以,这里的话给出代码

private DataSource getDataSource(Element dataSourceElement){
        String type = dataSourceElement.attributeValue("type").trim().toUpperCase(Locale.ROOT);
        //获取相关参数
        Map<String,String> map = new HashMap<>();
        List<Element> property = dataSourceElement.elements("property");
        property.forEach(element -> {
            String name = element.attributeValue("name");
            String value = element.attributeValue("value");
            map.put(name,value);

        });

        DataSource dataSource = null;
        if(type.equals(Const.UN_POOLED_DATASOURCE)){
            dataSource = new UnPooledDataSource(map.get("driver"),map.get("url"),map.get("username"),map.get("password"));
        }else if(type.equals(Const.POOLED_DATASOURCE)){
            dataSource = new PooledDataSource();//未实现
        }else if(type.equals(Const.JNDI_DATASOURCE)){
            dataSource = new JNDIDataSource();//未实现
        }
        return dataSource;
    }

    private Transaction getTransaction(Element transactionElement,DataSource dataSource){
        Transaction transaction = null;
        String type = transactionElement.attributeValue("type").trim().toUpperCase(Locale.ROOT);
        if(type.equals(Const.JDBC_TRANSACTION)){
            //默认开启事务
            transaction = new JDBCTransaction(false,dataSource);
        }else if(type.equals(Const.MANAGED_TRANSACTION)){
            transaction = new ManagedTransaction();
        }
        return transaction;
    }

然后是构造mapper的对象

private Map<String, MapperStatement> getMapper(List<String>sqlMapperXMLPathList){
        //解析所有的XML文件
        //这个map能解析namespace的精髓
        Map<String,MapperStatement> map = new HashMap<>();
        sqlMapperXMLPathList.forEach(path->{
            SAXReader reader = new SAXReader();
            try {
                Document document = reader.read(Resources.getResourceAsStream(path));
                //拿到一个mapper.xml文件,然后开始解析里面的内容
                //1.拿id,以及其余元素
                Element mapper = (Element) document.selectSingleNode("/mapper");
                String namespace = mapper.attributeValue("namespace");
                //2.拿Sql
                List<Element> elements = mapper.elements();
                elements.forEach(element -> {
                    String id = element.attributeValue("id");
                    //获取id,则拼接出sqlId
                    String sqlId = namespace+"."+id;
                    String resultType = element.attributeValue("resultType");
                    String sql = element.getTextTrim();
                    //获取解析后的sql
                    sql = sql.replaceAll("#\\{[0-9A-Za-z_$]*}","?");
                    MapperStatement mapperStatement = new MapperStatement(sql,resultType);
                    map.put(sqlId,mapperStatement);
                });
            } catch (DocumentException e) {
                e.printStackTrace();
            }
        });
        return map;
    }

核心解析流程代码

public SqlSessionFactory build(InputStream in){
    //解析输入流(document对象)
    //将document对象里面的内容全部提取出来,封装为SqlSessionFactory对象即可
    //建造者模式
    //1.解析XML文件,获取transaction以及mapper集合对象
    SqlSessionFactory factory = null;
    try {
        //思路,首先解析出数据源必要参数构造数据源
        //根据数据源构造事务管理器
        SAXReader saxReader = new SAXReader();
        Document document = saxReader.read(in);
        Element environments = (Element) document.selectSingleNode("/configuration/environments");
        String defaultId = environments.attributeValue("default");
        Element environment = (Element) document.selectSingleNode("/configuration/environments/environment[@id='"+defaultId+"']");
        Element transactionManagerElement = environment.element("transactionManager");
        Element dataSourceElement = environment.element("dataSource");
        //1.获取数据源对象
        DataSource dataSource = getDataSource(dataSourceElement);
        //2.依据数据源对象获取事务管理器
        Transaction transaction = getTransaction(transactionManagerElement,dataSource);
        //3.解析sql集合
        List<String> sqlMapperXMLPathLists = new ArrayList<>();
        List<Node> mappers = document.selectNodes("//mapper");
        mappers.forEach(mapper->{
            //获取路径集合
            sqlMapperXMLPathLists.add(((Element) mapper).attributeValue("resource"));
        });
        Map<String, MapperStatement> mapper = getMapper(sqlMapperXMLPathLists);
        factory = new SqlSessionFactory(transaction,mapper);
    } catch (Exception e) {
        e.printStackTrace();
    }
    return factory;
}

4. 封装SqlSession对象

SqlSession是什么?

SqlSession是一个用来描述与数据库连接过程的对象,当开启会话之后就能够通过此对象进行CRUD。而在Mybatis中,这个对象是通过SqlSessionFactory执行openSession()来拿到的

那么在执行sql的过程中,需要两个元素

  • 负责管理事务,负责向数据库提交数据的事务管理器
  • 负责获取用户编码的sql,通过id找到sql的map集合对象

而我们发现这两个元素正是SqlSessionFactory中含有的私有对象,那么在构造的时候,我们可以将当前构造完事务管理器的工厂对象作为组件传入SqlSession中,来构造这个会话对象

public class SqlSession {
    private SqlSessionFactory sqlSessionFactory;
    public SqlSession(SqlSessionFactory sqlSessionFactory){
        this.sqlSessionFactory = sqlSessionFactory;
    }
    public void commit(){
        sqlSessionFactory.getTransaction().commit();
    }

    public void close(){
        sqlSessionFactory.getTransaction().close();
    }

    public void rollback(){
        sqlSessionFactory.getTransaction().rollBack();
    }
}

5. 编写执行sql的方法

接着就是编写相关的sql方法,来执行这些方法了,我们先写出一个基本框架

public int insert(String sqlId,Object pojo){
    //封装JDBC代码
    int count = 0;
    Connection connection = sqlSessionFactory.getTransaction().getConnection();
    //原始sql,含有#{}等符号,这时候还需要提供一个工具类进行sql的解析,将相关的#{}变成?
    //对于指定了传参的参数,还需要通过反射等机制进行属性的一一映射
    String sql = sqlSessionFactory.getMapperStatementMap().get(sqlId).getSql();
    sql = sql.replaceAll("#\\{[a-zA-Z0-9_$]*}","?");
    try {
        PreparedStatement preparedStatement = connection.prepareStatement(sql);//编译sql
        //传值后执行
        //...先省下
        count = preparedStatement.executeUpdate();
    } catch (SQLException e) {
        e.printStackTrace();
    }
    //给占位符传值
    return count;
}

好了,接下来的问题就是,我们如何对占位符进行传值?

  • 需要知道有多少个问号:一个#对应一个?
  • 需要知道哪个数据给到哪个数据里面去:

为了简化处理,我们将数据类型转化为setString可以处理的类型,简化处理逻辑,理解核心思路

private Map<Integer,String> parseSql(String sql){
    int length = sql.length();
    Map<Integer, String> map = new HashMap<>();
    int i = 0;
    int leftIndex = 0;
    int rightIndex = 0;
    int index = 0;
    while (true){
        leftIndex = sql.indexOf("#",i);
        if(leftIndex <0 ){
            break;
        }
        index++;
        rightIndex = sql.indexOf("}",i);
        String propertyName = sql.substring(leftIndex + 2, rightIndex).trim();
        i = rightIndex+1;
        map.put(index,propertyName);
    }
    return map;
}
PreparedStatement preparedStatement = connection.prepareStatement(sql);//编译sql
//传值后执行
//然后进行赋值
for (Map.Entry<Integer, String> property : properties.entrySet()) {
    //获取属性名
    String value = property.getValue();
    Integer location = property.getKey();
    //通过反射获取object中的值
    String methodName ="get"+value.toUpperCase(Locale.ROOT).charAt(0) + value.substring(1);
    Method method = pojo.getClass().getDeclaredMethod(methodName);
    Object propertyValue = method.invoke(pojo);
    preparedStatement.setString(location,propertyValue.toString());
}

6. 实现selectOne

public Object selectOne(String sqlId,Object param){
    //执行查询语句,返回一个对象
    Object result = null;
    Connection connection = sqlSessionFactory.getTransaction().getConnection();
    MapperStatement mapperStatement = sqlSessionFactory.getMapperStatementMap().get(sqlId);
    //要封装的结果类型
    String sql = mapperStatement.getSql();
    sql = sql.replaceAll("#\\{[a-zA-Z0-9_$]*}","?");
    String resultType = mapperStatement.getResultType();
    try {
        PreparedStatement preparedStatement = connection.prepareStatement(sql);
        preparedStatement.setString(1,param.toString());
        ResultSet resultSet = preparedStatement.executeQuery();
        //查询返回结果集
        if(resultSet.next()){
            Class<?> clazz = Class.forName(resultType);
            result = clazz.newInstance();
            //通过反射,赋值
            ResultSetMetaData metaData = resultSet.getMetaData();
            int columnCount = metaData.getColumnCount();
            for (int i = 1; i <= columnCount; i++) {
                String columnName = metaData.getColumnName(i);
                //String getMethodName = "get"+columnName.toUpperCase(Locale.ROOT).charAt(0)+columnName.substring(1);
                String setMethodName = "set"+columnName.toUpperCase(Locale.ROOT).charAt(0)+columnName.substring(1);
                Method method = clazz.getDeclaredMethod(setMethodName, String.class);
                method.invoke(result,resultSet.getString(columnName));
            }
        }
    } catch (ClassNotFoundException e) {
        e.printStackTrace();
    } catch (SQLException e) {
        e.printStackTrace();
    } catch (InstantiationException e) {
        e.printStackTrace();
    } catch (IllegalAccessException e) {
        e.printStackTrace();
    } catch (NoSuchMethodException e) {
        e.printStackTrace();
    } catch (InvocationTargetException e) {
        e.printStackTrace();
    }
    return result;
}

文章作者: 穿山甲
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 穿山甲 !
  目录